fix: name differences

This commit is contained in:
Michael Poluektov 2024-08-08 11:01:00 +01:00
parent e6bbce439d
commit 8cdf9814bd

View File

@ -148,23 +148,24 @@ def apply_model_params_to_body(
return form_data return form_data
OPENAI_MAPPINGS = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
# inplace function: form_data is modified # inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS) mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
return apply_model_params_to_body(params, form_data, mappings)
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
opts = [ opts = [
"temperature",
"top_p",
"seed",
"mirostat", "mirostat",
"mirostat_eta", "mirostat_eta",
"mirostat_tau", "mirostat_tau",
@ -180,12 +181,18 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
"num_thread", "num_thread",
] ]
mappings = {i: lambda x: x for i in opts} mappings = {i: lambda x: x for i in opts}
mappings = {**mappings, **OPENAI_MAPPINGS}
form_data = apply_model_params_to_body(params, form_data, mappings) form_data = apply_model_params_to_body(params, form_data, mappings)
# only param that changes name name_differences = {
if (param := params.get("frequency_penalty", None)) is not None: "max_tokens": "num_predict",
form_data["repeat_penalty"] = param "frequency_penalty": "repeat_penalty",
}
for key, value in name_differences.items():
if (param := params.get(key, None)) is not None:
form_data[value] = param
print(form_data)
return form_data return form_data