refac: max_tokens -> max_completion_tokens

This commit is contained in:
Timothy J. Baek 2024-09-19 17:19:31 +02:00
parent 60d6279055
commit f8fffdd288

View File

@ -1398,9 +1398,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
model_id = get_task_model_id(model_id)
task_model_id = get_task_model_id(model_id)
print(model_id)
print(task_model_id)
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@ -1427,10 +1427,16 @@ Prompt: {{prompt:middletruncate:8000}}"""
)
payload = {
"model": model_id,
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 50,
**(
{"max_tokens": 50}
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 50,
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
}
@ -1475,9 +1481,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
model_id = get_task_model_id(model_id)
print(model_id)
task_model_id = get_task_model_id(model_id)
print(task_model_id)
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@ -1499,10 +1504,16 @@ Search Query:"""
print("content", content)
payload = {
"model": model_id,
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
**(
{"max_tokens": 30}
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 30,
}
),
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
}
@ -1541,9 +1552,8 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
model_id = get_task_model_id(model_id)
print(model_id)
task_model_id = get_task_model_id(model_id)
print(task_model_id)
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@ -1561,10 +1571,16 @@ Message: """{{prompt}}"""
)
payload = {
"model": model_id,
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 4,
**(
{"max_tokens": 4}
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 4,
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
}