diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py
index 51181a5..e8c27d1 100644
--- a/examples/pipelines/providers/mlx_manifold_pipeline.py
+++ b/examples/pipelines/providers/mlx_manifold_pipeline.py
@@ -6,10 +6,8 @@ version: 2.0
 license: MIT
 description: A pipeline for generating text using Apple MLX Framework with dynamic model loading.
 requirements: requests, mlx-lm, huggingface-hub, psutil
-environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
 """

-import argparse
 from typing import List, Union, Generator, Iterator
 from schemas import OpenAIChatMessage
 from pydantic import BaseModel
@@ -32,30 +30,35 @@ class Pipeline:
     MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False

     def __init__(self):
+        # Pipeline identification
         self.type = "manifold"
         self.id = "mlx"
         self.name = "MLX/"

+        # Initialize valves and update them
         self.valves = self.Valves()
         self.update_valves()

-        self.host = os.getenv("MLX_HOST", "localhost")
-        self.port = os.getenv("MLX_PORT", "8080")
-        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+        # Server configuration
+        self.host = "localhost"  # Always use localhost for security
+        self.port = None  # Port will be dynamically assigned

+        # Model management
         self.models = self.get_mlx_models()
         self.current_model = None
         self.server_process = None

-        if self.subprocess:
-            self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
+        # Start the MLX server with the default model
+        self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)

     def update_valves(self):
+        """Update pipeline configuration based on valve settings."""
         if self.valves.HUGGINGFACE_TOKEN:
             login(self.valves.HUGGINGFACE_TOKEN)
         self.stop_sequence = self.valves.MLX_STOP.split(",")

     def get_mlx_models(self):
+        """Fetch available MLX models based on the specified pattern."""
         try:
             cmd = [
                 'mlx_lm.manage',
@@ -65,11 +68,10 @@ class Pipeline:
             result = subprocess.run(cmd, capture_output=True, text=True)
             lines = result.stdout.strip().split('\n')

-            # Skip header lines and the line with dashes
             content_lines = [line for line in lines if line and not line.startswith('-')]

             models = []
-            for line in content_lines[2:]:  # Skip the first two lines (header)
+            for line in content_lines[2:]:  # Skip header lines
                 parts = line.split()
                 if len(parts) >= 2:
                     repo_id = parts[0]
@@ -93,9 +95,11 @@ class Pipeline:
             }]

     def pipelines(self) -> List[dict]:
+        """Return the list of available models as pipelines."""
         return self.models

     def start_mlx_server(self, model_name):
+        """Start the MLX server with the specified model."""
         model_id = f"mlx.{model_name.split('/')[-1].lower()}"
         if self.current_model == model_id and self.server_process and self.server_process.poll() is None:
             logging.info(f"MLX server already running with model {model_name}")
@@ -103,8 +107,7 @@ class Pipeline:

         self.stop_mlx_server()

-        if not os.getenv("MLX_PORT"):
-            self.port = self.find_free_port()
+        self.port = self.find_free_port()

         command = [
             "mlx_lm.server",
@@ -112,6 +115,7 @@ class Pipeline:
             "--port", str(self.port),
         ]

+        # Add chat template options if specified
         if self.valves.MLX_CHAT_TEMPLATE:
             command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE])
         elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE:
@@ -124,6 +128,7 @@ class Pipeline:
         time.sleep(5)  # Give the server some time to start up

     def stop_mlx_server(self):
+        """Stop the currently running MLX server."""
         if self.server_process:
             try:
                 process = psutil.Process(self.server_process.pid)
@@ -138,9 +143,11 @@ class Pipeline:
             finally:
                 self.server_process = None
                 self.current_model = None
-                logging.info(f"Stopped MLX server on port {self.port}")
+                self.port = None
+                logging.info("Stopped MLX server")

     def find_free_port(self):
+        """Find and return a free port to use for the MLX server."""
         import socket
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s.bind(("", 0))
@@ -149,23 +156,26 @@ class Pipeline:
         return port

     async def on_startup(self):
+        """Perform any necessary startup operations."""
         logging.info(f"on_startup:{__name__}")

     async def on_shutdown(self):
-        if self.subprocess:
-            self.stop_mlx_server()
+        """Perform cleanup operations on shutdown."""
+        self.stop_mlx_server()

     async def on_valves_updated(self):
+        """Handle updates to the pipeline configuration."""
         self.update_valves()
         self.models = self.get_mlx_models()
-        if self.subprocess:
-            self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
+        self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)

     def pipe(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
     ) -> Union[str, Generator, Iterator]:
+        """Process a request through the MLX pipeline."""
         logging.info(f"pipe:{__name__}")

+        # Switch model if necessary
         if model_id != self.current_model:
             model_name = next((model['name'] for model in self.models if model['id'] == model_id), self.valves.MLX_DEFAULT_MODEL)
             self.start_mlx_server(model_name)
@@ -173,6 +183,7 @@ class Pipeline:
         url = f"http://{self.host}:{self.port}/v1/chat/completions"
         headers = {"Content-Type": "application/json"}

+        # Prepare the payload for the MLX server
         max_tokens = body.get("max_tokens", 4096)
         temperature = body.get("temperature", 0.8)
         repeat_penalty = body.get("repeat_penalty", 1.0)
@@ -187,11 +198,13 @@ class Pipeline:
         }

         try:
+            # Send request to MLX server
             r = requests.post(
                 url, headers=headers, json=payload, stream=body.get("stream", False)
             )
             r.raise_for_status()

+            # Return streamed response or full JSON response
             if body.get("stream", False):
                 return r.iter_lines()
             else:
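
Note: the dynamic port assignment this patch switches to works by binding a socket to port 0, which makes the OS hand back an unused ephemeral port. A minimal standalone sketch of the pattern, separate from the patched file (the function name here is illustrative):

    import socket

    def find_free_port() -> int:
        """Ask the OS for an unused TCP port by binding to port 0."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))  # Port 0 tells the kernel to pick a free port
            return s.getsockname()[1]

There is a small race inherent to this technique: the socket is closed before mlx_lm.server binds the port, so another process could claim it in between. For a local, single-user pipeline that trade-off is generally acceptable.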
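The shutdown path uses psutil rather than a bare Popen.terminate() so that any child processes spawned by mlx_lm.server are cleaned up too. A sketch of that terminate-then-kill pattern, assuming a Popen handle like the pipeline's server_process (the helper name is illustrative):

    import subprocess
    import psutil

    def stop_process_tree(popen: subprocess.Popen, timeout: float = 5.0) -> None:
        """Terminate a process and its children, escalating to kill on timeout."""
        try:
            parent = psutil.Process(popen.pid)
        except psutil.NoSuchProcess:
            return  # Process already exited
        procs = parent.children(recursive=True) + [parent]
        for proc in procs:
            proc.terminate()  # Polite SIGTERM first
        _, alive = psutil.wait_procs(procs, timeout=timeout)
        for proc in alive:
            proc.kill()  # Force SIGKILL for anything still running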