refac: re-use utils.misc

This commit is contained in:
Michael Poluektov
2024-08-06 11:31:45 +01:00
parent 44c781f414
commit fc31267a54
4 changed files with 85 additions and 229 deletions

View File

@@ -2,7 +2,7 @@ from pathlib import Path
import hashlib
import re
from datetime import timedelta
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Callable
import uuid
import time
@@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
def apply_model_params_to_body(
params: dict, form_data: dict, mappings: dict[str, Callable]
) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
@@ -155,6 +148,42 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
return form_data
OPENAI_MAPPINGS = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
# inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
opts = [
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_batch",
"num_keep",
"repeat_last_n",
"tfs_z",
"top_k",
"min_p",
"use_mmap",
"use_mlock",
"num_thread",
]
mappings = {i: lambda x: x for i in opts}
mappings = {**mappings, **OPENAI_MAPPINGS}
return apply_model_params_to_body(params, form_data, mappings)
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters