From 90aa29528c0360dc0a93a4c9c6ccc64d5d1b3102 Mon Sep 17 00:00:00 2001 From: dannyl1u Date: Thu, 27 Feb 2025 23:13:30 -0800 Subject: [PATCH] logit_bias: handle comma seperated values --- backend/open_webui/utils/middleware.py | 7 ++++--- backend/open_webui/utils/misc.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 359ef775c..d52be2487 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -68,6 +68,7 @@ from open_webui.utils.misc import ( get_last_user_message, get_last_assistant_message, prepend_to_first_user_message_content, + convert_logit_bias_input_to_json ) from open_webui.utils.tools import get_tools from open_webui.utils.plugin import load_function_module_by_id @@ -593,9 +594,9 @@ def apply_params_to_form_data(form_data, model): form_data["reasoning_effort"] = params["reasoning_effort"] if "logit_bias" in params: try: - form_data["logit_bias"] = json.loads(params["logit_bias"]) - except json.JSONDecodeError: - print("Invalid JSON format for logit_bias") + form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"])) + except Exception as e: + print(f"Error parsing logit_bias: {e}") return form_data diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index f79b62684..8ab743316 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -5,6 +5,7 @@ import uuid from datetime import timedelta from pathlib import Path from typing import Callable, Optional +import json import collections.abc @@ -445,3 +446,14 @@ def parse_ollama_modelfile(model_text): data["params"]["messages"] = messages return data + +def convert_logit_bias_input_to_json(user_input): + logit_bias_pairs = user_input.split(',') + logit_bias_json = {} + for pair in logit_bias_pairs: + token, bias = pair.split(':') + token = str(token.strip()) + bias = int(bias.strip()) + bias = 100 if bias > 100 else -100 if bias < -100 else bias + logit_bias_json[token] = bias + return json.dumps(logit_bias_json) \ No newline at end of file