diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 43fd0d480..de2b9c468 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -68,6 +68,7 @@ from open_webui.utils.misc import (
     get_last_user_message,
     get_last_assistant_message,
     prepend_to_first_user_message_content,
+    convert_logit_bias_input_to_json
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
@@ -610,6 +611,11 @@ def apply_params_to_form_data(form_data, model):
 
         if "reasoning_effort" in params:
             form_data["reasoning_effort"] = params["reasoning_effort"]
+        if "logit_bias" in params:
+            try:
+                form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
+            except Exception as e:
+                print(f"Error parsing logit_bias: {e}")
 
     return form_data
 
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index 8f867bace..85e7f6415 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -6,6 +6,7 @@ import logging
 from datetime import timedelta
 from pathlib import Path
 from typing import Callable, Optional
+import json
 
 import collections.abc
 
@@ -450,3 +451,14 @@ def parse_ollama_modelfile(model_text):
         data["params"]["messages"] = messages
 
     return data
+
+def convert_logit_bias_input_to_json(user_input):
+    logit_bias_pairs = user_input.split(',')
+    logit_bias_json = {}
+    for pair in logit_bias_pairs:
+        token, bias = pair.split(':')
+        token = str(token.strip())
+        bias = int(bias.strip())
+        bias = 100 if bias > 100 else -100 if bias < -100 else bias
+        logit_bias_json[token] = bias
+    return json.dumps(logit_bias_json)
\ No newline at end of file
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index 869e70895..46656cc82 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -62,6 +62,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
         "reasoning_effort": str,
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+        "logit_bias": lambda x: x,
     }
 
     return apply_model_params_to_body(params, form_data, mappings)
diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte
index f1b8e8e52..5b10230ed 100644
--- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte
+++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte
@@ -17,6 +17,7 @@
 		stop: null,
 		temperature: null,
 		reasoning_effort: null,
+		logit_bias: null,
 		frequency_penalty: null,
 		repeat_last_n: null,
 		mirostat: null,
@@ -298,6 +299,49 @@
 	{/if}
+
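
A quick usage sketch of the parsing helper introduced in misc.py and of how apply_params_to_form_data consumes it, assuming the helper is importable as shown in the diff; the token IDs (464, 3290) and bias values are made-up examples, not values taken from the change:

```python
import json

from open_webui.utils.misc import convert_logit_bias_input_to_json

# The Logit Bias field accepts comma-separated "token:bias" pairs;
# each bias is clamped to the [-100, 100] range accepted by the OpenAI API.
raw = "464:150, 3290:-50"
as_json = convert_logit_bias_input_to_json(raw)
print(as_json)    # {"464": 100, "3290": -50}  (150 is clamped to 100)

# apply_params_to_form_data in middleware.py loads the JSON string back into
# a dict before the payload is forwarded to the upstream API.
form_data = {"logit_bias": json.loads(as_json)}
print(form_data)  # {'logit_bias': {'464': 100, '3290': -50}}
```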