logit_bias: handle comma seperated values

This commit is contained in:
dannyl1u 2025-02-27 23:13:30 -08:00
parent 34e3cb6881
commit 90aa29528c
2 changed files with 16 additions and 3 deletions

View File

@ -68,6 +68,7 @@ from open_webui.utils.misc import (
get_last_user_message, get_last_user_message,
get_last_assistant_message, get_last_assistant_message,
prepend_to_first_user_message_content, prepend_to_first_user_message_content,
convert_logit_bias_input_to_json
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.plugin import load_function_module_by_id
@ -593,9 +594,9 @@ def apply_params_to_form_data(form_data, model):
form_data["reasoning_effort"] = params["reasoning_effort"] form_data["reasoning_effort"] = params["reasoning_effort"]
if "logit_bias" in params: if "logit_bias" in params:
try: try:
form_data["logit_bias"] = json.loads(params["logit_bias"]) form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
except json.JSONDecodeError: except Exception as e:
print("Invalid JSON format for logit_bias") print(f"Error parsing logit_bias: {e}")
return form_data return form_data

View File

@ -5,6 +5,7 @@ import uuid
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
import json
import collections.abc import collections.abc
@ -445,3 +446,14 @@ def parse_ollama_modelfile(model_text):
data["params"]["messages"] = messages data["params"]["messages"] = messages
return data return data
def convert_logit_bias_input_to_json(user_input):
logit_bias_pairs = user_input.split(',')
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(':')
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)