diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 51fa711ca..e64995af5 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -287,6 +287,26 @@ def get_extra_params(metadata: dict):
     }
 
 
+def add_model_params(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
+
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
 async def generate_function_chat_completion(form_data, user):
     print("entry point")
     model_id = form_data.get("model")
@@ -300,24 +320,9 @@ async def generate_function_chat_completion(form_data, user):
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-
-        if params:
-            mappings = {
-                "temperature": float,
-                "top_p": int,
-                "max_tokens": int,
-                "frequency_penalty": int,
-                "seed": lambda x: x,
-                "stop": lambda x: [
-                    bytes(s, "utf-8").decode("unicode_escape") for s in x
-                ],
-            }
-
-            for key, cast_func in mappings.items():
-                if (value := params.get(key)) is not None:
-                    form_data[key] = cast_func(value)
-
         system = params.get("system", None)
+        form_data = add_model_params(params, form_data)
+
         if system:
             if user:
                 template_params = {
@@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user):
                     yield process_line(form_data, line)
 
             if isinstance(res, str) or isinstance(res, Generator):
-                finish_message = stream_message_template(form_data, "")
+                finish_message = stream_message_template(form_data["model"], "")
                 finish_message["choices"][0]["finish_reason"] = "stop"
                 yield f"data: {json.dumps(finish_message)}\n\n"
                 yield "data: [DONE]"
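
For context, a minimal standalone sketch of how the extracted helper behaves. The helper body is copied from the diff above; the sample `params` and `form_data` values are made up for illustration and are not part of the change.

```python
# Sketch only: the function body mirrors add_model_params from the diff;
# the example values below are hypothetical.
def add_model_params(params: dict, form_data: dict) -> dict:
    if not params:
        return form_data

    mappings = {
        "temperature": float,
        "top_p": int,
        "max_tokens": int,
        "frequency_penalty": int,
        "seed": lambda x: x,
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
    }

    # Copy only the recognised keys onto the request body, casting each value.
    for key, cast_func in mappings.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast_func(value)

    return form_data


# Hypothetical model params, e.g. as returned by model_info.params.model_dump().
params = {
    "temperature": "0.7",          # cast to float -> 0.7
    "max_tokens": "256",           # cast to int   -> 256
    "seed": 42,                    # passed through unchanged
    "stop": ["\\n\\n", "User:"],   # unicode-escape decoded -> ["\n\n", "User:"]
    "system": "You are helpful.",  # not in the mapping; handled by the system-prompt code
}
form_data = {"model": "my-pipe-model", "messages": []}

form_data = add_model_params(params, form_data)
print(form_data)
# {'model': 'my-pipe-model', 'messages': [], 'temperature': 0.7,
#  'max_tokens': 256, 'seed': 42, 'stop': ['\n\n', 'User:']}
```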