Merge pull request #10373 from dannyl1u/logit_bias

feat: logit bias
This commit is contained in:
Timothy Jaeryang Baek
2025-03-01 06:13:19 -08:00
committed by GitHub
5 changed files with 66 additions and 0 deletions

View File

@@ -68,6 +68,7 @@ from open_webui.utils.misc import (
get_last_user_message,
get_last_assistant_message,
prepend_to_first_user_message_content,
convert_logit_bias_input_to_json
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
@@ -610,6 +611,11 @@ def apply_params_to_form_data(form_data, model):
if "reasoning_effort" in params:
form_data["reasoning_effort"] = params["reasoning_effort"]
if "logit_bias" in params:
try:
form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
except Exception as e:
print(f"Error parsing logit_bias: {e}")
return form_data

View File

@@ -6,6 +6,7 @@ import logging
from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
import json
import collections.abc
@@ -450,3 +451,14 @@ def parse_ollama_modelfile(model_text):
data["params"]["messages"] = messages
return data
def convert_logit_bias_input_to_json(user_input):
logit_bias_pairs = user_input.split(',')
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(':')
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)

View File

@@ -62,6 +62,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
"reasoning_effort": str,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
"logit_bias": lambda x: x,
}
return apply_model_params_to_body(params, form_data, mappings)