fix/refac: use ollama /api/chat endpoint for tasks

Timothy J. Baek 2024-09-21 00:30:13 +02:00
parent 585b9eb84a
commit 41926172d3
3 changed files with 111 additions and 21 deletions
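
For context, the task handlers touched below (title, search-query, emoji, and MoA completions) now call Ollama's native /api/chat route when the task model is owned by Ollama, converting the OpenAI-style payload on the way in and the Ollama response on the way out. A minimal standalone request against that endpoint, for reference only (assumes a local Ollama on the default port and an available model; this snippet is not part of the diff):

import requests

# Ollama's native chat endpoint, which the task flows now reach via
# generate_ollama_chat_completion instead of the OpenAI-compatible wrapper.
r = requests.post(
    "http://localhost:11434/api/chat",
    json={
        "model": "llama3.1",  # illustrative model name
        "messages": [{"role": "user", "content": "Generate a concise title."}],
        "stream": False,
        "options": {"num_predict": 50},
    },
    timeout=60,
)
print(r.json()["message"]["content"])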

View File

@@ -19,7 +19,9 @@ from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
generate_openai_chat_completion as generate_ollama_chat_completion,
GenerateChatCompletionForm,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
@@ -135,6 +137,9 @@ from open_webui.utils.utils import (
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import convert_response_ollama_to_openai
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
@@ -1048,7 +1053,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
return await generate_ollama_openai_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@@ -1399,9 +1404,10 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(model_id)
print(task_model_id)
model = app.state.MODELS[task_model_id]
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
else:
@@ -1440,9 +1446,9 @@ Prompt: {{prompt:middletruncate:8000}}"""
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
}
log.debug(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
@@ -1456,11 +1462,17 @@ Prompt: {{prompt:middletruncate:8000}}"""
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
# Check if task model is ollama model
if model["owned_by"] == "ollama":
payload = convert_payload_openai_to_ollama(payload)
form_data = GenerateChatCompletionForm(**payload)
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
return convert_response_ollama_to_openai(response)
else:
return await generate_chat_completions(form_data=payload, user=user)
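
The same branch is repeated for the query, emoji, and MoA endpoints below. Condensed into a single helper, the new Ollama path looks roughly like this; the function name here is ours (not from the repository), while the imports are the ones added in this commit:

from open_webui.apps.ollama.main import (
    GenerateChatCompletionForm,
    generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import convert_response_ollama_to_openai


async def complete_task_via_ollama(payload: dict, user) -> dict:
    # Translate the OpenAI-style task payload into Ollama's /api/chat format,
    # call the native Ollama handler, then translate the reply back.
    ollama_payload = convert_payload_openai_to_ollama(payload)
    form_data = GenerateChatCompletionForm(**ollama_payload)
    response = await generate_ollama_chat_completion(form_data=form_data, user=user)
    return convert_response_ollama_to_openai(response)

The handlers in the diff inline this same sequence in each task endpoint rather than sharing a helper.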
@app.post("/api/task/query/completions")
@@ -1484,6 +1496,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
task_model_id = get_task_model_id(model_id)
print(task_model_id)
model = app.state.MODELS[task_model_id]
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
else:
@@ -1516,9 +1530,9 @@ Search Query:"""
),
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
}
log.debug(payload)
print(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
@@ -1532,11 +1546,17 @@ Search Query:"""
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
# Check if task model is ollama model
if model["owned_by"] == "ollama":
payload = convert_payload_openai_to_ollama(payload)
form_data = GenerateChatCompletionForm(**payload)
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
return convert_response_ollama_to_openai(response)
else:
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
@@ -1555,12 +1575,13 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
task_model_id = get_task_model_id(model_id)
print(task_model_id)
model = app.state.MODELS[task_model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: """{{prompt}}"""
'''
content = title_generation_template(
template,
form_data["prompt"],
@@ -1584,9 +1605,9 @@ Message: """{{prompt}}"""
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
}
log.debug(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
@@ -1600,11 +1621,17 @@ Message: """{{prompt}}"""
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
# Check if task model is ollama model
if model["owned_by"] == "ollama":
payload = convert_payload_openai_to_ollama(payload)
form_data = GenerateChatCompletionForm(**payload)
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
return convert_response_ollama_to_openai(response)
else:
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/moa/completions")
@@ -1623,6 +1650,8 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
task_model_id = get_task_model_id(model_id)
print(task_model_id)
model = app.state.MODELS[task_model_id]
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
@@ -1635,8 +1664,6 @@ Responses from models: {{responses}}"""
form_data["responses"],
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
@@ -1644,9 +1671,6 @@ Responses from models: {{responses}}"""
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
}
log.debug(payload)
try:
@@ -1662,11 +1686,17 @@ Responses from models: {{responses}}"""
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
# Check if task model is ollama model
if model["owned_by"] == "ollama":
payload = convert_payload_openai_to_ollama(payload)
form_data = GenerateChatCompletionForm(**payload)
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
return convert_response_ollama_to_openai(response)
else:
return await generate_chat_completions(form_data=payload, user=user)
##################################

View File

@@ -86,3 +86,49 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
form_data[value] = param
return form_data
def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"""
Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.
Args:
openai_payload (dict): The payload originally designed for OpenAI API usage.
Returns:
dict: A modified payload compatible with the Ollama API.
"""
ollama_payload = {}
# Mapping basic model and message details
ollama_payload["model"] = openai_payload.get("model")
ollama_payload["messages"] = openai_payload.get("messages")
ollama_payload["stream"] = openai_payload.get("stream", False)
# If there are advanced parameters in the payload, format them in Ollama's options field
ollama_options = {}
# Handle parameters which map directly
for param in ["temperature", "top_p", "seed"]:
if param in openai_payload:
ollama_options[param] = openai_payload[param]
# Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_completion_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
elif "max_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_tokens"]
# Handle frequency_penalty / presence_penalty, which need renaming and checking
if "frequency_penalty" in openai_payload:
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
# We are assuming presence_penalty maps to a similar concept in Ollama, which may need custom handling if such an option exists.
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
# Add options to payload if any have been set
if ollama_options:
ollama_payload["options"] = ollama_options
return ollama_payload
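
A quick illustration of the mapping performed by convert_payload_openai_to_ollama, assuming it is importable via the path used in the new import above (input values are made up):

from open_webui.utils.payload import convert_payload_openai_to_ollama

openai_payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Generate a concise title."}],
    "stream": False,
    "temperature": 0.1,
    "max_tokens": 50,          # becomes options["num_predict"]
    "frequency_penalty": 1.1,  # becomes options["repeat_penalty"]
}

print(convert_payload_openai_to_ollama(openai_payload))
# {'model': 'llama3.1',
#  'messages': [{'role': 'user', 'content': 'Generate a concise title.'}],
#  'stream': False,
#  'options': {'temperature': 0.1, 'num_predict': 50, 'repeat_penalty': 1.1}}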

View File

@@ -0,0 +1,14 @@
from open_webui.utils.task import prompt_template
from open_webui.utils.misc import (
openai_chat_completion_message_template,
)
from typing import Callable, Optional
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")
response = openai_chat_completion_message_template(model, message_content)
return response
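
And a small usage sketch for the response converter, assuming the non-streaming /api/chat response shape (only "model" and the message "content" are read; the exact output shape comes from openai_chat_completion_message_template):

from open_webui.utils.response import convert_response_ollama_to_openai

# Illustrative non-streaming Ollama /api/chat response (values made up).
ollama_response = {
    "model": "llama3.1",
    "created_at": "2024-09-21T00:30:13Z",
    "message": {"role": "assistant", "content": "Exploring Ollama's Chat API"},
    "done": True,
}

openai_response = convert_response_ollama_to_openai(ollama_response)
# An OpenAI-style chat completion dict with the text under
# choices[0]["message"]["content"].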