From 2e934304cfb65a8b7d6af880dce0042827c73fc8 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 30 May 2024 02:04:07 -0700 Subject: [PATCH] feat: filter outlet & libretranslate example --- main.py | 2 +- .../libretranslate_filter_pipeline.py | 140 ++++++++++++++++++ utils/main.py | 19 ++- 3 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 pipelines/examples/libretranslate_filter_pipeline.py diff --git a/main.py b/main.py index e42590f..1af5071 100644 --- a/main.py +++ b/main.py @@ -526,7 +526,7 @@ async def filter_outlet(pipeline_id: str, form_data: FilterForm): @app.post("/v1/chat/completions") @app.post("/chat/completions") async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm): - user_message = get_last_user_message(form_data.messages) + user_message = get_last_user_message([message.model_dump() for message in form_data.messages]) messages = [message.model_dump() for message in form_data.messages] if ( diff --git a/pipelines/examples/libretranslate_filter_pipeline.py b/pipelines/examples/libretranslate_filter_pipeline.py new file mode 100644 index 0000000..5a8745b --- /dev/null +++ b/pipelines/examples/libretranslate_filter_pipeline.py @@ -0,0 +1,140 @@ +from typing import List, Optional +from schemas import OpenAIChatMessage +from pydantic import BaseModel +import requests +import os + +from utils.main import get_last_user_message, get_last_assistant_message + + +class Pipeline: + def __init__(self): + # Pipeline filters are only compatible with Open WebUI + # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API. + self.type = "filter" + + # Optionally, you can set the id and name of the pipeline. + # Assign a unique identifier to the pipeline. + # The identifier must be unique across all pipelines. + # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. 
+ self.id = "libretranslate_filter_pipeline" + self.name = "LibreTranslate Filter" + + class Valves(BaseModel): + # List target pipeline ids (models) that this filter will be connected to. + # If you want to connect this filter to all pipelines, you can set pipelines to ["*"] + # e.g. ["llama3:latest", "gpt-3.5-turbo"] + pipelines: List[str] = [] + + # Assign a priority level to the filter pipeline. + # The priority level determines the order in which the filter pipelines are executed. + # The lower the number, the higher the priority. + priority: int = 0 + + # Valves + libretranslate_url: str + + # Source and target languages + # User message will be translated from source_user to target_user + source_user: Optional[str] = "auto" + target_user: Optional[str] = "en" + + # Assistant languages + # Assistant message will be translated from source_assistant to target_assistant + source_assistant: Optional[str] = "en" + target_assistant: Optional[str] = "es" + + # Initialize + self.valves = Valves( + **{ + "pipelines": ["*"], # Connect to all pipelines + "libretranslate_url": os.getenv( + "LIBRETRANSLATE_API_BASE_URL", "http://localhost:5000" + ), + } + ) + + pass + + async def on_startup(self): + # This function is called when the server is started. + print(f"on_startup:{__name__}") + pass + + async def on_shutdown(self): + # This function is called when the server is stopped. + print(f"on_shutdown:{__name__}") + pass + + async def on_valves_updated(self): + # This function is called when the valves are updated. 
+ pass + + def translate(self, text: str, source: str, target: str) -> str: + payload = { + "q": text, + "source": source, + "target": target, + } + + try: + r = requests.post( + f"{self.valves.libretranslate_url}/translate", json=payload + ) + r.raise_for_status() + + data = r.json() + return data["translatedText"] + except Exception as e: + print(f"Error translating text: {e}") + return text + + async def inlet(self, body: dict, user: Optional[dict] = None) -> dict: + print(f"inlet:{__name__}") + + messages = body["messages"] + user_message = get_last_user_message(messages) + + print(f"User message: {user_message}") + + # Translate user message + translated_user_message = self.translate( + user_message, + self.valves.source_user, + self.valves.target_user, + ) + + print(f"Translated user message: {translated_user_message}") + + for message in reversed(messages): + if message["role"] == "user": + message["content"] = translated_user_message + break + + body = {**body, "messages": messages} + return body + + async def outlet(self, body: dict, user: Optional[dict] = None) -> dict: + print(f"outlet:{__name__}") + + messages = body["messages"] + assistant_message = get_last_assistant_message(messages) + + print(f"Assistant message: {assistant_message}") + + # Translate assistant message + translated_assistant_message = self.translate( + assistant_message, + self.valves.source_assistant, + self.valves.target_assistant, + ) + + print(f"Translated assistant message: {translated_assistant_message}") + + for message in reversed(messages): + if message["role"] == "assistant": + message["content"] = translated_assistant_message + break + + body = {**body, "messages": messages} + return body diff --git a/utils/main.py b/utils/main.py index a50cddf..96026ea 100644 --- a/utils/main.py +++ b/utils/main.py @@ -24,6 +24,21 @@ def stream_message_template(model: str, message: str): def get_last_user_message(messages: List[dict]) -> str: for message in reversed(messages): - if 
message.role == "user": - return message.content + if message["role"] == "user": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def get_last_assistant_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message["role"] == "assistant": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] return None