logit_bias: handle comma seperated values

This commit is contained in:
dannyl1u 2025-02-27 23:13:30 -08:00
parent 34e3cb6881
commit 90aa29528c
2 changed files with 16 additions and 3 deletions

View File

@ -68,6 +68,7 @@ from open_webui.utils.misc import (
get_last_user_message,
get_last_assistant_message,
prepend_to_first_user_message_content,
convert_logit_bias_input_to_json
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
@ -593,9 +594,9 @@ def apply_params_to_form_data(form_data, model):
form_data["reasoning_effort"] = params["reasoning_effort"]
if "logit_bias" in params:
try:
form_data["logit_bias"] = json.loads(params["logit_bias"])
except json.JSONDecodeError:
print("Invalid JSON format for logit_bias")
form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
except Exception as e:
print(f"Error parsing logit_bias: {e}")
return form_data

View File

@ -5,6 +5,7 @@ import uuid
from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
import json
import collections.abc
@ -445,3 +446,14 @@ def parse_ollama_modelfile(model_text):
data["params"]["messages"] = messages
return data
def convert_logit_bias_input_to_json(user_input):
logit_bias_pairs = user_input.split(',')
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(':')
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)