From 4a2a12fd21c15f1671f82315b7897a0f9af4f6f3 Mon Sep 17 00:00:00 2001 From: dannyl1u Date: Wed, 19 Feb 2025 10:33:49 -0800 Subject: [PATCH 1/4] feat: scaffolding for logit_bias --- backend/open_webui/utils/payload.py | 1 + .../Settings/Advanced/AdvancedParams.svelte | 44 +++++++++++++++++++ .../components/chat/Settings/General.svelte | 3 ++ 3 files changed, 48 insertions(+) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 5eb040434..d078362ee 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -61,6 +61,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: "reasoning_effort": str, "seed": lambda x: x, "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + "logit_bias": lambda x: x, } return apply_model_params_to_body(params, form_data, mappings) diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index d0648bba5..5e53b999e 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -17,6 +17,7 @@ stop: null, temperature: null, reasoning_effort: null, + logit_bias: null, frequency_penalty: null, repeat_last_n: null, mirostat: null, @@ -298,6 +299,49 @@ {/if} +
+ +
+
+ {$i18n.t('Logit Bias')} +
+ +
+
+ + {#if (params?.logit_bias ?? null) !== null} +
+
+ +
+
+ {/if} +
+
Date: Wed, 19 Feb 2025 16:23:58 -0800 Subject: [PATCH 2/4] include logit_bias in form_data --- backend/open_webui/utils/middleware.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index ba55c095e..359ef775c 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -591,6 +591,11 @@ def apply_params_to_form_data(form_data, model): if "reasoning_effort" in params: form_data["reasoning_effort"] = params["reasoning_effort"] + if "logit_bias" in params: + try: + form_data["logit_bias"] = json.loads(params["logit_bias"]) + except json.JSONDecodeError: + print("Invalid JSON format for logit_bias") return form_data From 34e3cb688147a1143963a120017d6e98292ce121 Mon Sep 17 00:00:00 2001 From: dannyl1u Date: Thu, 27 Feb 2025 23:13:09 -0800 Subject: [PATCH 3/4] logit bias: update tooltip message --- .../components/chat/Settings/Advanced/AdvancedParams.svelte | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index 5e53b999e..8d59fdef2 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -302,7 +302,7 @@
From 90aa29528c0360dc0a93a4c9c6ccc64d5d1b3102 Mon Sep 17 00:00:00 2001 From: dannyl1u Date: Thu, 27 Feb 2025 23:13:30 -0800 Subject: [PATCH 4/4] logit_bias: handle comma seperated values --- backend/open_webui/utils/middleware.py | 7 ++++--- backend/open_webui/utils/misc.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 359ef775c..d52be2487 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -68,6 +68,7 @@ from open_webui.utils.misc import ( get_last_user_message, get_last_assistant_message, prepend_to_first_user_message_content, + convert_logit_bias_input_to_json ) from open_webui.utils.tools import get_tools from open_webui.utils.plugin import load_function_module_by_id @@ -593,9 +594,9 @@ def apply_params_to_form_data(form_data, model): form_data["reasoning_effort"] = params["reasoning_effort"] if "logit_bias" in params: try: - form_data["logit_bias"] = json.loads(params["logit_bias"]) - except json.JSONDecodeError: - print("Invalid JSON format for logit_bias") + form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"])) + except Exception as e: + print(f"Error parsing logit_bias: {e}") return form_data diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index f79b62684..8ab743316 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -5,6 +5,7 @@ import uuid from datetime import timedelta from pathlib import Path from typing import Callable, Optional +import json import collections.abc @@ -445,3 +446,14 @@ def parse_ollama_modelfile(model_text): data["params"]["messages"] = messages return data + +def convert_logit_bias_input_to_json(user_input): + logit_bias_pairs = user_input.split(',') + logit_bias_json = {} + for pair in logit_bias_pairs: + token, bias = pair.split(':') + token = str(token.strip()) + bias = int(bias.strip()) + bias = 100 if bias > 100 else -100 if bias < -100 else bias + logit_bias_json[token] = bias + return json.dumps(logit_bias_json) \ No newline at end of file