diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py
index 71faa4f..c8dcd4a 100644
--- a/pipelines/examples/mlx_pipeline.py
+++ b/pipelines/examples/mlx_pipeline.py
@@ -5,15 +5,17 @@
 date: 2024-05-22
 version: 1.0
 license: MIT
 description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm
-environment_variables: MLX_MODEL
+dependencies: requests, mlx-lm, huggingface_hub
+environment_variables: MLX_MODEL, MLX_STOP, HUGGINGFACE_TOKEN
 """
 
 from typing import List, Union, Generator, Iterator
-import requests
 import subprocess
 import os
 import socket
+import time
+import requests
+from huggingface_hub import login
 
 from schemas import OpenAIChatMessage
@@ -23,21 +25,30 @@ class Pipeline:
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
         self.process = None
-        self.model = os.getenv(
-            "MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
-        )  # Default model if not set in environment variable
+        self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2') # Default model if not set in environment variable
         self.port = self.find_free_port()
-        self.stop_sequences = os.getenv(
-            "MLX_STOP", "[INST]"
-        )  # Stop sequences from environment variable
+        self.stop_sequences = os.getenv('MLX_STOP', '[INST]') # Stop sequences from environment variable
+        self.hf_token = os.getenv('HUGGINGFACE_TOKEN', None) # Hugging Face token from environment variable
+
+        # Authenticate with Hugging Face if a token is provided
+        if self.hf_token:
+            self.authenticate_huggingface(self.hf_token)
 
     @staticmethod
     def find_free_port():
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
+            s.bind(('', 0))
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             return s.getsockname()[1]
 
+    @staticmethod
+    def authenticate_huggingface(token: str):
+        try:
+            login(token)
+            print("Successfully authenticated with Hugging Face.")
+        except Exception as e:
+            print(f"Failed to authenticate with Hugging Face: {e}")
+
     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
@@ -54,13 +65,18 @@ class Pipeline:
             self.process = subprocess.Popen(
                 ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
                 stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
-            )
-            print(
-                f"Subprocess started with PID: {self.process.pid} on port {self.port}"
+                stderr=subprocess.PIPE
             )
+            print(f"Subprocess started with PID: {self.process.pid} on port {self.port}")
+
+            # Check if the process has started correctly
+            time.sleep(2) # Give it a moment to start
+            if self.process.poll() is not None:
+                raise RuntimeError(f"Subprocess failed to start. Return code: {self.process.returncode}")
+
         except Exception as e:
             print(f"Failed to start subprocess: {e}")
+            self.process = None
 
     def stop_subprocess(self):
         # Stop the subprocess if it is running
@@ -71,6 +87,8 @@ class Pipeline:
                 print(f"Subprocess with PID {self.process.pid} terminated")
             except Exception as e:
                 print(f"Failed to terminate subprocess: {e}")
+            finally:
+                self.process = None
 
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
@@ -78,9 +96,15 @@ class Pipeline:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")
 
+        if not self.process or self.process.poll() is not None:
+            return "Error: Subprocess is not running."
+
         MLX_BASE_URL = f"http://localhost:{self.port}"
         MODEL = self.model
 
+        # Convert OpenAIChatMessage objects to dictionaries
+        messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+
         # Extract additional parameters from the body
         temperature = body.get("temperature", 0.8)
         max_tokens = body.get("max_tokens", 1000)
@@ -96,12 +120,12 @@ class Pipeline:
 
         payload = {
             "model": MODEL,
-            "messages": [message.model_dump() for message in messages],
+            "messages": messages_dict,
             "temperature": temperature,
             "max_tokens": max_tokens,
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
-            "stop": stop,
+            "stop": stop
         }
 
         try:
@@ -115,4 +139,4 @@ class Pipeline:
 
             return r.iter_lines()
         except Exception as e:
-            return f"Error: {e}"
+            return f"Error: {e}"
\ No newline at end of file
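A minimal sketch of how the reworked request path can be exercised by hand, assuming an mlx_lm.server instance is already running locally. The port, model name, and stop list below are placeholders (the pipeline itself picks a free port via find_free_port and reads MLX_MODEL / MLX_STOP from the environment); the payload mirrors the shape that get_response() now builds from messages_dict.

import requests

PORT = 8080  # placeholder; the pipeline discovers a free port at runtime

# Same payload shape that get_response() assembles after converting
# OpenAIChatMessage objects into plain dictionaries
payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "temperature": 0.8,
    "max_tokens": 1000,
    "top_p": 1.0,              # example value; the pipeline reads this from the request body
    "repetition_penalty": 1.0, # example value; the pipeline reads this from the request body
    "stop": ["[INST]"],        # matches the MLX_STOP default
}

# mlx_lm.server exposes an OpenAI-compatible chat completions route
r = requests.post(f"http://localhost:{PORT}/v1/chat/completions", json=payload)
r.raise_for_status()
print(r.json())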