Merge branch 'dev' into fix/ollama-cancellation

This commit is contained in:
Timothy Jaeryang Baek
2024-06-02 16:27:01 -07:00
committed by GitHub
47 changed files with 999 additions and 565 deletions

View File

@@ -734,44 +734,77 @@ async def generate_chat_completion(
if model_info.params:
payload["options"] = {}
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("mirostat", None):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("mirostat_eta", None):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("mirostat_tau", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("frequency_penalty", None):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("temperature", None):
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if model_info.params.get("seed", None):
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("tfs_z", None):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("max_tokens", None):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if model_info.params.get("top_k", None):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("top_p", None):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("use_mmap", None):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if model_info.params.get("use_mlock", None):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if model_info.params.get("num_thread", None):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message

View File

@@ -239,6 +239,27 @@ async def get_all_models(raw: bool = False):
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)