diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 78bb587dd..55f9acfdf 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -19,7 +19,9 @@ from open_webui.apps.audio.main import app as audio_app
 from open_webui.apps.images.main import app as images_app
 from open_webui.apps.ollama.main import app as ollama_app
 from open_webui.apps.ollama.main import (
-    generate_openai_chat_completion as generate_ollama_chat_completion,
+    GenerateChatCompletionForm,
+    generate_chat_completion as generate_ollama_chat_completion,
+    generate_openai_chat_completion as generate_ollama_openai_chat_completion,
 )
 from open_webui.apps.ollama.main import get_all_models as get_ollama_models
 from open_webui.apps.openai.main import app as openai_app
@@ -135,6 +137,9 @@ from open_webui.utils.utils import (
 )
 from open_webui.utils.webhook import post_webhook
 
+from open_webui.utils.payload import convert_payload_openai_to_ollama
+from open_webui.utils.response import convert_response_ollama_to_openai
+
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
@@ -1048,7 +1053,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(form_data, user=user)
+        return await generate_ollama_openai_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1399,9 +1404,10 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
     task_model_id = get_task_model_id(model_id)
-    print(task_model_id)
+    model = app.state.MODELS[task_model_id]
+
 
     if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
     else:
@@ -1440,9 +1446,9 @@ Prompt: {{prompt:middletruncate:8000}}"""
         "chat_id": form_data.get("chat_id", None),
         "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
-
     log.debug(payload)
 
+    # Handle pipeline filters
     try:
         payload = filter_pipeline(payload, user)
     except Exception as e:
@@ -1456,11 +1462,17 @@ Prompt: {{prompt:middletruncate:8000}}"""
             status_code=status.HTTP_400_BAD_REQUEST,
             content={"detail": str(e)},
         )
-
     if "chat_id" in payload:
         del payload["chat_id"]
 
-    return await generate_chat_completions(form_data=payload, user=user)
+    # Check if task model is ollama model
+    if model["owned_by"] == "ollama":
+        payload = convert_payload_openai_to_ollama(payload)
+        form_data = GenerateChatCompletionForm(**payload)
+        response = await generate_ollama_chat_completion(form_data=form_data, user=user)
+        return convert_response_ollama_to_openai(response)
+    else:
+        return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/query/completions")
@@ -1484,6 +1496,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
     task_model_id = get_task_model_id(model_id)
     print(task_model_id)
 
+    model = app.state.MODELS[task_model_id]
+
     if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
     else:
@@ -1516,9 +1530,9 @@ Search Query:"""
         ),
         "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
+    log.debug(payload)
 
-    print(payload)
-
+    # Handle pipeline filters
     try:
         payload = filter_pipeline(payload, user)
     except Exception as e:
@@ -1532,11 +1546,17 @@ Search Query:"""
             status_code=status.HTTP_400_BAD_REQUEST,
             content={"detail": str(e)},
         )
-
     if "chat_id" in payload:
         del payload["chat_id"]
 
-    return await generate_chat_completions(form_data=payload, user=user)
+    # Check if task model is ollama model
+    if model["owned_by"] == "ollama":
+        payload = convert_payload_openai_to_ollama(payload)
+        form_data = GenerateChatCompletionForm(**payload)
+        response = await generate_ollama_chat_completion(form_data=form_data, user=user)
+        return convert_response_ollama_to_openai(response)
+    else:
+        return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/emoji/completions")
@@ -1555,12 +1575,13 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
     task_model_id = get_task_model_id(model_id)
     print(task_model_id)
 
+    model = app.state.MODELS[task_model_id]
+
     template = '''
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
 
 Message: """{{prompt}}"""
 '''
-
     content = title_generation_template(
         template,
         form_data["prompt"],
@@ -1584,9 +1605,9 @@ Message: """{{prompt}}"""
         "chat_id": form_data.get("chat_id", None),
         "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
-
     log.debug(payload)
 
+    # Handle pipeline filters
     try:
         payload = filter_pipeline(payload, user)
     except Exception as e:
@@ -1600,11 +1621,17 @@ Message: """{{prompt}}"""
             status_code=status.HTTP_400_BAD_REQUEST,
             content={"detail": str(e)},
         )
-
     if "chat_id" in payload:
         del payload["chat_id"]
 
-    return await generate_chat_completions(form_data=payload, user=user)
+    # Check if task model is ollama model
+    if model["owned_by"] == "ollama":
+        payload = convert_payload_openai_to_ollama(payload)
+        form_data = GenerateChatCompletionForm(**payload)
+        response = await generate_ollama_chat_completion(form_data=form_data, user=user)
+        return convert_response_ollama_to_openai(response)
+    else:
+        return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/moa/completions")
@@ -1623,6 +1650,8 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
     task_model_id = get_task_model_id(model_id)
     print(task_model_id)
 
+    model = app.state.MODELS[task_model_id]
+
     template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
 
 Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
@@ -1635,8 +1664,6 @@ Responses from models: {{responses}}"""
         form_data["responses"],
     )
 
-
-
     payload = {
         "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
@@ -1644,9 +1671,6 @@ Responses from models: {{responses}}"""
         "chat_id": form_data.get("chat_id", None),
         "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
     }
-
-
-
     log.debug(payload)
 
     try:
@@ -1662,11 +1686,17 @@ Responses from models: {{responses}}"""
             status_code=status.HTTP_400_BAD_REQUEST,
             content={"detail": str(e)},
         )
-
     if "chat_id" in payload:
         del payload["chat_id"]
 
-    return await generate_chat_completions(form_data=payload, user=user)
+    # Check if task model is ollama model
+    if model["owned_by"] == "ollama":
+        payload = convert_payload_openai_to_ollama(payload)
+        form_data = GenerateChatCompletionForm(**payload)
+        response = await generate_ollama_chat_completion(form_data=form_data, user=user)
+        return convert_response_ollama_to_openai(response)
+    else:
+        return await generate_chat_completions(form_data=payload, user=user)
 
 
 ##################################
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index b2654cd25..72aec6a6c 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -86,3 +86,49 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
             form_data[value] = param
 
     return form_data
+
+
+def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
+    """
+    Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.
+
+    Args:
+        openai_payload (dict): The payload originally designed for OpenAI API usage.
+
+    Returns:
+        dict: A modified payload compatible with the Ollama API.
+    """
+    ollama_payload = {}
+
+    # Mapping basic model and message details
+    ollama_payload["model"] = openai_payload.get("model")
+    ollama_payload["messages"] = openai_payload.get("messages")
+    ollama_payload["stream"] = openai_payload.get("stream", False)
+
+    # If there are advanced parameters in the payload, format them in Ollama's options field
+    ollama_options = {}
+
+    # Handle parameters which map directly
+    for param in ["temperature", "top_p", "seed"]:
+        if param in openai_payload:
+            ollama_options[param] = openai_payload[param]
+
+    # Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
+    if "max_completion_tokens" in openai_payload:
+        ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
+    elif "max_tokens" in openai_payload:
+        ollama_options["num_predict"] = openai_payload["max_tokens"]
+
+    # Handle frequency / presence_penalty, which needs renaming and checking
+    if "frequency_penalty" in openai_payload:
+        ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
+
+    if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
+        # We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
+        ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
+
+    # Add options to payload if any have been set
+    if ollama_options:
+        ollama_payload["options"] = ollama_options
+
+    return ollama_payload
diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
new file mode 100644
index 000000000..22275488f
--- /dev/null
+++ b/backend/open_webui/utils/response.py
@@ -0,0 +1,14 @@
+from open_webui.utils.task import prompt_template
+from open_webui.utils.misc import (
+    openai_chat_completion_message_template,
+)
+
+from typing import Callable, Optional
+
+
+def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
+    model = ollama_response.get("model", "ollama")
+    message_content = ollama_response.get("message", {}).get("content", "")
+
+    response = openai_chat_completion_message_template(model, message_content)
+    return response