mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 05:24:02 +00:00
fix/refac: use ollama /api/chat endpoint for tasks
This commit is contained in:
parent
585b9eb84a
commit
41926172d3
@ -19,7 +19,9 @@ from open_webui.apps.audio.main import app as audio_app
|
||||
from open_webui.apps.images.main import app as images_app
|
||||
from open_webui.apps.ollama.main import app as ollama_app
|
||||
from open_webui.apps.ollama.main import (
|
||||
generate_openai_chat_completion as generate_ollama_chat_completion,
|
||||
GenerateChatCompletionForm,
|
||||
generate_chat_completion as generate_ollama_chat_completion,
|
||||
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
|
||||
)
|
||||
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
|
||||
from open_webui.apps.openai.main import app as openai_app
|
||||
@ -135,6 +137,9 @@ from open_webui.utils.utils import (
|
||||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||
from open_webui.utils.response import convert_response_ollama_to_openai
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
@ -1048,7 +1053,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
if model.get("pipe"):
|
||||
return await generate_function_chat_completion(form_data, user=user)
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(form_data, user=user)
|
||||
return await generate_ollama_openai_chat_completion(form_data, user=user)
|
||||
else:
|
||||
return await generate_openai_chat_completion(form_data, user=user)
|
||||
|
||||
@ -1399,9 +1404,10 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
|
||||
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
@ -1440,9 +1446,9 @@ Prompt: {{prompt:middletruncate:8000}}"""
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
|
||||
}
|
||||
|
||||
log.debug(payload)
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
except Exception as e:
|
||||
@ -1456,11 +1462,17 @@ Prompt: {{prompt:middletruncate:8000}}"""
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
# Check if task model is ollama model
|
||||
if model["owned_by"] == "ollama":
|
||||
payload = convert_payload_openai_to_ollama(payload)
|
||||
form_data = GenerateChatCompletionForm(**payload)
|
||||
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/query/completions")
|
||||
@ -1484,6 +1496,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
|
||||
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
@ -1516,9 +1530,9 @@ Search Query:"""
|
||||
),
|
||||
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
|
||||
}
|
||||
log.debug(payload)
|
||||
|
||||
print(payload)
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
except Exception as e:
|
||||
@ -1532,11 +1546,17 @@ Search Query:"""
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
# Check if task model is ollama model
|
||||
if model["owned_by"] == "ollama":
|
||||
payload = convert_payload_openai_to_ollama(payload)
|
||||
form_data = GenerateChatCompletionForm(**payload)
|
||||
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/emoji/completions")
|
||||
@ -1555,12 +1575,13 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
|
||||
template = '''
|
||||
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
||||
|
||||
Message: """{{prompt}}"""
|
||||
'''
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
@ -1584,9 +1605,9 @@ Message: """{{prompt}}"""
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
|
||||
}
|
||||
|
||||
log.debug(payload)
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
except Exception as e:
|
||||
@ -1600,11 +1621,17 @@ Message: """{{prompt}}"""
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
# Check if task model is ollama model
|
||||
if model["owned_by"] == "ollama":
|
||||
payload = convert_payload_openai_to_ollama(payload)
|
||||
form_data = GenerateChatCompletionForm(**payload)
|
||||
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/moa/completions")
|
||||
@ -1623,6 +1650,8 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
|
||||
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
||||
|
||||
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
@ -1635,8 +1664,6 @@ Responses from models: {{responses}}"""
|
||||
form_data["responses"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
@ -1644,9 +1671,6 @@ Responses from models: {{responses}}"""
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
|
||||
}
|
||||
|
||||
|
||||
|
||||
log.debug(payload)
|
||||
|
||||
try:
|
||||
@ -1662,11 +1686,17 @@ Responses from models: {{responses}}"""
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
# Check if task model is ollama model
|
||||
if model["owned_by"] == "ollama":
|
||||
payload = convert_payload_openai_to_ollama(payload)
|
||||
form_data = GenerateChatCompletionForm(**payload)
|
||||
response = await generate_ollama_chat_completion(form_data=form_data, user=user)
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
##################################
|
||||
|
@ -86,3 +86,49 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
form_data[value] = param
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
"""
|
||||
Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.
|
||||
|
||||
Args:
|
||||
openai_payload (dict): The payload originally designed for OpenAI API usage.
|
||||
|
||||
Returns:
|
||||
dict: A modified payload compatible with the Ollama API.
|
||||
"""
|
||||
ollama_payload = {}
|
||||
|
||||
# Mapping basic model and message details
|
||||
ollama_payload["model"] = openai_payload.get("model")
|
||||
ollama_payload["messages"] = openai_payload.get("messages")
|
||||
ollama_payload["stream"] = openai_payload.get("stream", False)
|
||||
|
||||
# If there are advanced parameters in the payload, format them in Ollama's options field
|
||||
ollama_options = {}
|
||||
|
||||
# Handle parameters which map directly
|
||||
for param in ["temperature", "top_p", "seed"]:
|
||||
if param in openai_payload:
|
||||
ollama_options[param] = openai_payload[param]
|
||||
|
||||
# Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_completion_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
|
||||
elif "max_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_tokens"]
|
||||
|
||||
# Handle frequency / presence_penalty, which needs renaming and checking
|
||||
if "frequency_penalty" in openai_payload:
|
||||
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
|
||||
|
||||
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
|
||||
# We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
|
||||
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
|
||||
|
||||
# Add options to payload if any have been set
|
||||
if ollama_options:
|
||||
ollama_payload["options"] = ollama_options
|
||||
|
||||
return ollama_payload
|
||||
|
14
backend/open_webui/utils/response.py
Normal file
14
backend/open_webui/utils/response.py
Normal file
@ -0,0 +1,14 @@
|
||||
from open_webui.utils.task import prompt_template
|
||||
from open_webui.utils.misc import (
|
||||
openai_chat_completion_message_template,
|
||||
)
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
model = ollama_response.get("model", "ollama")
|
||||
message_content = ollama_response.get("message", {}).get("content", "")
|
||||
|
||||
response = openai_chat_completion_message_template(model, message_content)
|
||||
return response
|
Loading…
Reference in New Issue
Block a user