diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py
index a79dcd6..85b7b30 100644
--- a/pipelines/examples/mlx_pipeline.py
+++ b/pipelines/examples/mlx_pipeline.py
@@ -1,118 +1,99 @@
 """
 title: MLX Pipeline
 author: justinh-rahb
-date: 2024-05-22
-version: 1.0
+date: 2024-05-27
+version: 1.1
 license: MIT
-description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm
-environment_variables: MLX_MODEL
+description: A pipeline for generating text using Apple MLX Framework.
+dependencies: requests, mlx-lm, huggingface-hub
+environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN
 """
 
 from typing import List, Union, Generator, Iterator
-import requests
-import subprocess
-import os
-import socket
 from schemas import OpenAIChatMessage
-
+import requests
+import os
+import subprocess
+import logging
+from huggingface_hub import login
 
 class Pipeline:
     def __init__(self):
-        # Optionally, you can set the id and name of the pipeline.
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
-        self.process = None
-        self.model = os.getenv(
-            "MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
-        )  # Default model if not set in environment variable
-        self.port = self.find_free_port()
-        self.stop_sequences = os.getenv(
-            "MLX_STOP", "[INST]"
-        )  # Stop sequences from environment variable
+        self.host = os.getenv("MLX_HOST", "localhost")
+        self.port = os.getenv("MLX_PORT", "8080")
+        self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+        self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",")  # Default stop sequence is [INST]
+        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
 
-    @staticmethod
-    def find_free_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            return s.getsockname()[1]
+        if self.huggingface_token:
+            login(self.huggingface_token)
+
+        if self.subprocess:
+            self.start_mlx_server()
+
+    def start_mlx_server(self):
+        if not os.getenv("MLX_PORT"):
+            self.port = self.find_free_port()
+        command = f"mlx_lm.server --model {self.model} --port {self.port}"
+        self.server_process = subprocess.Popen(command, shell=True)
+        logging.info(f"Started MLX server on port {self.port}")
+
+    def find_free_port(self):
+        import socket
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(('', 0))
+        port = s.getsockname()[1]
+        s.close()
+        return port
 
     async def on_startup(self):
-        # This function is called when the server is started.
-        print(f"on_startup:{__name__}")
-        self.start_subprocess()
+        logging.info(f"on_startup:{__name__}")
 
     async def on_shutdown(self):
-        # This function is called when the server is stopped.
-        print(f"on_shutdown:{__name__}")
-        self.stop_subprocess()
-
-    def start_subprocess(self):
-        # Start the subprocess for "mlx_lm.server --model ${MLX_MODEL} --port ${PORT}"
-        try:
-            self.process = subprocess.Popen(
-                ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
-            )
-            print(
-                f"Subprocess started with PID: {self.process.pid} on port {self.port}"
-            )
-        except Exception as e:
-            print(f"Failed to start subprocess: {e}")
-
-    def stop_subprocess(self):
-        # Stop the subprocess if it is running
-        if self.process:
-            try:
-                self.process.terminate()
-                self.process.wait()
-                print(f"Subprocess with PID {self.process.pid} terminated")
-            except Exception as e:
-                print(f"Failed to terminate subprocess: {e}")
+        if self.subprocess and hasattr(self, 'server_process'):
+            self.server_process.terminate()
+            logging.info(f"Terminated MLX server on port {self.port}")
 
     def get_response(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
-        # This is where you can add your custom pipelines like RAG.'
-        print(f"get_response:{__name__}")
+        logging.info(f"get_response:{__name__}")
 
-        MLX_BASE_URL = f"http://localhost:{self.port}"
-        MODEL = self.model
+        url = f"http://{self.host}:{self.port}/v1/chat/completions"
+        headers = {"Content-Type": "application/json"}
+
+        # Extract and validate parameters from the request body
+        max_tokens = body.get("max_tokens", 1024)
+        if not isinstance(max_tokens, int) or max_tokens < 0:
+            max_tokens = 1024  # Default to 1024 if invalid
 
-        # Extract additional parameters from the body
         temperature = body.get("temperature", 0.8)
-        max_tokens = body.get("max_tokens", 1000)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        stop = self.stop_sequences
+        if not isinstance(temperature, (int, float)) or temperature < 0:
+            temperature = 0.8  # Default to 0.8 if invalid
 
-        if "user" in body:
-            print("######################################")
-            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
-            print(f"# Message: {user_message}")
-            print("######################################")
+        repeat_penalty = body.get("repeat_penalty", 1.0)
+        if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0:
+            repeat_penalty = 1.0  # Default to 1.0 if invalid
 
         payload = {
-            "model": MODEL,
-            "messages": [message.model_dump() for message in messages],
-            "temperature": temperature,
+            "messages": messages,
             "max_tokens": max_tokens,
-            "top_p": top_p,
-            "repetition_penalty": repetition_penalty,
-            "stop": stop,
+            "temperature": temperature,
+            "repetition_penalty": repeat_penalty,
+            "stop": self.stop_sequence,
+            "stream": body.get("stream", False)
         }
 
         try:
-            r = requests.post(
-                url=f"{MLX_BASE_URL}/v1/chat/completions",
-                json=payload,
-                stream=True,
-            )
-
+            r = requests.post(url, headers=headers, json=payload, stream=body.get("stream", False))
             r.raise_for_status()
 
-            return r.iter_lines()
+            if body.get("stream", False):
+                return r.iter_lines()
+            else:
+                return r.json()
         except Exception as e:
             return f"Error: {e}"
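For reviewers who want to exercise the change locally, the sketch below is a minimal smoke test, not part of the patch. It assumes an mlx_lm.server instance is already reachable at localhost:8080 (so MLX_SUBPROCESS is set to "false" to skip start_mlx_server()), that the repo's schemas module is importable, and that mlx_pipeline resolves as a module name; the prompt text and sampling parameters are illustrative.

    # Hypothetical smoke test for the revised pipeline (not part of this patch).
    import os

    # Configure the pipeline before import-time side effects in __init__.
    os.environ["MLX_SUBPROCESS"] = "false"  # do not spawn mlx_lm.server
    os.environ["MLX_HOST"] = "localhost"    # assumed server location
    os.environ["MLX_PORT"] = "8080"

    from mlx_pipeline import Pipeline  # assumed import path

    pipeline = Pipeline()
    messages = [{"role": "user", "content": "Hello!"}]
    body = {"max_tokens": 256, "temperature": 0.7, "stream": False}

    # Non-streaming: get_response() returns the parsed JSON response dict.
    print(pipeline.get_response("Hello!", pipeline.id, messages, body))

    # Streaming: get_response() returns r.iter_lines(), yielding raw
    # server-sent-event lines as bytes.
    stream_body = {**body, "stream": True}
    for line in pipeline.get_response("Hello!", pipeline.id, messages, stream_body):
        if line:
            print(line.decode("utf-8"))

Note that in the non-streaming path the patch returns r.json() verbatim, i.e. the full OpenAI-style completion object rather than just the message content, so callers are expected to unpack choices[0] themselves.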