From deec41d29a850afa62f1607da231ad25aee88cc2 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 13:51:25 +0100 Subject: [PATCH] fix: function early returns --- backend/apps/webui/main.py | 110 +++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 331713b07..13761f8cb 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -291,12 +291,7 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): return params -async def generate_function_chat_completion(form_data, user): - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", None) - +def get_extra_params(metadata: dict): __event_emitter__ = __event_call__ = __task__ = None if metadata: @@ -305,58 +300,67 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) - if not model_info: - return - - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - - if params: - mappings = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - } - - for key, cast_func in mappings.items(): - if (value := params.get(key)) is not None: - form_data[key] = cast_func(value) - - system = params.get("system", None) - if not system: - return - - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - - system = prompt_template(system, **template_params) - - # Check if the payload already has a system message - # If not, add a system message to the payload - for message in form_data.get("messages", []): - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - if form_data.get("messages"): - form_data["messages"].insert(0, {"role": "system", "content": system}) - - extra_params = { + return { "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, "__task__": __task__, } + +async def generate_function_chat_completion(form_data, user): + print("entry point") + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", None) + extra_params = get_extra_params(metadata) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + + if params: + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [ + bytes(s, "utf-8").decode("unicode_escape") for s in x + ], + } + + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) + + system = params.get("system", None) + if system: + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + + system = prompt_template(system, **template_params) + + # Check if the payload already has a system message + # If not, add a system message to the payload + for message in form_data.get("messages", []): + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + if form_data.get("messages"): + form_data["messages"].insert( + 0, {"role": "system", "content": system} + ) + async def job(): pipe_id = get_pipe_id(form_data) function_module = get_function_module(pipe_id)