Remove MLX_SUBPROCESS=True option

Justin Hayes authored on 2024-07-01 11:13:40 -04:00, committed by GitHub
parent 72e933cd6b
commit 83db4c2035


@@ -6,10 +6,8 @@ version: 2.0
 license: MIT
 description: A pipeline for generating text using Apple MLX Framework with dynamic model loading.
 requirements: requests, mlx-lm, huggingface-hub, psutil
-environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
 """
-import argparse
 from typing import List, Union, Generator, Iterator
 from schemas import OpenAIChatMessage
 from pydantic import BaseModel
@@ -32,30 +30,35 @@ class Pipeline:
         MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False
 
     def __init__(self):
+        # Pipeline identification
         self.type = "manifold"
         self.id = "mlx"
         self.name = "MLX/"
+        # Initialize valves and update them
         self.valves = self.Valves()
         self.update_valves()
-        self.host = os.getenv("MLX_HOST", "localhost")
-        self.port = os.getenv("MLX_PORT", "8080")
-        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+        # Server configuration
+        self.host = "localhost"  # Always use localhost for security
+        self.port = None  # Port will be dynamically assigned
+        # Model management
         self.models = self.get_mlx_models()
         self.current_model = None
         self.server_process = None
-        if self.subprocess:
-            self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
+        # Start the MLX server with the default model
+        self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
 
     def update_valves(self):
+        """Update pipeline configuration based on valve settings."""
         if self.valves.HUGGINGFACE_TOKEN:
             login(self.valves.HUGGINGFACE_TOKEN)
         self.stop_sequence = self.valves.MLX_STOP.split(",")
 
     def get_mlx_models(self):
+        """Fetch available MLX models based on the specified pattern."""
        try:
            cmd = [
                'mlx_lm.manage',
@@ -65,11 +68,10 @@ class Pipeline:
             result = subprocess.run(cmd, capture_output=True, text=True)
             lines = result.stdout.strip().split('\n')
 
-            # Skip header lines and the line with dashes
             content_lines = [line for line in lines if line and not line.startswith('-')]
 
             models = []
-            for line in content_lines[2:]:  # Skip the first two lines (header)
+            for line in content_lines[2:]:  # Skip header lines
                 parts = line.split()
                 if len(parts) >= 2:
                     repo_id = parts[0]
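
For reference, the parsing in this hunk can be exercised on its own. A minimal sketch, assuming `mlx_lm.manage` prints a two-line header, a dashed separator row, and one repo per line; the sample output below is an assumed shape, not captured from the tool:

# Minimal sketch of the model-list parsing used above.
# The sample output is an assumed shape, not real mlx_lm.manage output.
sample_stdout = """Scanning Hugging Face cache
REPO ID                                        SIZE
----------------------------------------------------
mlx-community/Meta-Llama-3-8B-Instruct-4bit    4.5G
mlx-community/Phi-3-mini-4k-instruct-4bit      2.2G
"""

lines = sample_stdout.strip().split('\n')
# Drop blanks and the dashed separator, then skip the two header lines
content_lines = [line for line in lines if line and not line.startswith('-')]
models = []
for line in content_lines[2:]:
    parts = line.split()
    if len(parts) >= 2:
        models.append(parts[0])  # repo id is the first column

print(models)  # ['mlx-community/Meta-Llama-3-8B-Instruct-4bit', ...]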
@@ -93,9 +95,11 @@ class Pipeline:
             }]
 
     def pipelines(self) -> List[dict]:
+        """Return the list of available models as pipelines."""
         return self.models
 
     def start_mlx_server(self, model_name):
+        """Start the MLX server with the specified model."""
         model_id = f"mlx.{model_name.split('/')[-1].lower()}"
         if self.current_model == model_id and self.server_process and self.server_process.poll() is None:
             logging.info(f"MLX server already running with model {model_name}")
@@ -103,7 +107,6 @@ class Pipeline:
 
         self.stop_mlx_server()
 
-        if not os.getenv("MLX_PORT"):
-            self.port = self.find_free_port()
+        self.port = self.find_free_port()
 
         command = [
@@ -112,6 +115,7 @@ class Pipeline:
             "--port", str(self.port),
         ]
 
+        # Add chat template options if specified
         if self.valves.MLX_CHAT_TEMPLATE:
             command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE])
         elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE:
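
Pieced together, the launch sequence around this hunk plausibly builds a command like the sketch below. The model name and port are illustrative, and the --use-default-chat-template flag is an assumption inferred from the valve name rather than read from the file:

import subprocess

# Sketch of the server launch this hunk configures. The model name is an
# example; the chat-template handling mirrors the valves shown in the diff.
model_name = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
port = 8080  # the pipeline would use self.find_free_port() instead

command = ["mlx_lm.server", "--model", model_name, "--port", str(port)]

chat_template = None          # stands in for self.valves.MLX_CHAT_TEMPLATE
use_default_template = False  # stands in for self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE
if chat_template:
    command.extend(["--chat-template", chat_template])
elif use_default_template:
    command.append("--use-default-chat-template")  # flag name is an assumption

server_process = subprocess.Popen(command)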
@@ -124,6 +128,7 @@ class Pipeline:
         time.sleep(5)  # Give the server some time to start up
 
     def stop_mlx_server(self):
+        """Stop the currently running MLX server."""
         if self.server_process:
             try:
                 process = psutil.Process(self.server_process.pid)
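
The body of stop_mlx_server between this hunk and the next is elided by the diff view. A common psutil pattern for this job (terminate the whole process tree, then escalate to kill) is sketched here as an illustration, not the file's exact code:

import psutil

def terminate_tree(pid: int, timeout: float = 5.0) -> None:
    # Terminate a process and all its children, escalating to kill
    # if anything is still alive after the timeout.
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return  # already gone
    procs = parent.children(recursive=True) + [parent]
    for proc in procs:
        proc.terminate()  # polite SIGTERM first
    _, alive = psutil.wait_procs(procs, timeout=timeout)
    for proc in alive:
        proc.kill()  # force SIGKILL on stragglers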
@@ -138,9 +143,11 @@ class Pipeline:
             finally:
                 self.server_process = None
                 self.current_model = None
-                logging.info(f"Stopped MLX server on port {self.port}")
+                self.port = None
+                logging.info("Stopped MLX server")
 
     def find_free_port(self):
+        """Find and return a free port to use for the MLX server."""
         import socket
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s.bind(("", 0))
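
The tail of find_free_port falls between hunks. The bind-to-port-zero technique it visibly uses is standard and completes like this sketch; note the small race window between closing the probe socket and the server binding the port:

import socket

def find_free_port() -> int:
    # Ask the OS for an unused ephemeral port by binding to port 0.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]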
@@ -149,23 +156,26 @@ class Pipeline:
         return port
 
     async def on_startup(self):
+        """Perform any necessary startup operations."""
         logging.info(f"on_startup:{__name__}")
 
     async def on_shutdown(self):
-        if self.subprocess:
-            self.stop_mlx_server()
+        """Perform cleanup operations on shutdown."""
+        self.stop_mlx_server()
 
     async def on_valves_updated(self):
+        """Handle updates to the pipeline configuration."""
         self.update_valves()
         self.models = self.get_mlx_models()
-        if self.subprocess:
-            self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
+        self.start_mlx_server(self.valves.MLX_DEFAULT_MODEL)
 
     def pipe(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
     ) -> Union[str, Generator, Iterator]:
+        """Process a request through the MLX pipeline."""
         logging.info(f"pipe:{__name__}")
 
+        # Switch model if necessary
         if model_id != self.current_model:
             model_name = next((model['name'] for model in self.models if model['id'] == model_id), self.valves.MLX_DEFAULT_MODEL)
             self.start_mlx_server(model_name)
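
The lookup above maps a pipeline model_id back to a repo name, falling back to the default model when the id is unknown. A minimal sketch with example entries (the ids shown are illustrative):

# Sketch of the id -> name lookup with a default fallback, as used above.
models = [
    {"id": "mlx.meta-llama-3-8b-instruct-4bit",
     "name": "mlx-community/Meta-Llama-3-8B-Instruct-4bit"},  # example entry
]
default_model = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_id = "mlx.meta-llama-3-8b-instruct-4bit"

model_name = next(
    (m["name"] for m in models if m["id"] == model_id),
    default_model,  # fall back when the id is unknown
)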
@@ -173,6 +183,7 @@ class Pipeline:
         url = f"http://{self.host}:{self.port}/v1/chat/completions"
         headers = {"Content-Type": "application/json"}
 
+        # Prepare the payload for the MLX server
         max_tokens = body.get("max_tokens", 4096)
         temperature = body.get("temperature", 0.8)
         repeat_penalty = body.get("repeat_penalty", 1.0)
@@ -187,11 +198,13 @@ class Pipeline:
         }
 
         try:
+            # Send request to MLX server
             r = requests.post(
                 url, headers=headers, json=payload, stream=body.get("stream", False)
             )
             r.raise_for_status()
 
+            # Return streamed response or full JSON response
             if body.get("stream", False):
                 return r.iter_lines()
             else:
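
Once the server is up, the OpenAI-compatible endpoint assembled in pipe() can also be exercised directly. A hedged usage sketch; the host and port are illustrative, since the pipeline assigns the port at runtime:

import requests

url = "http://localhost:8080/v1/chat/completions"  # port is dynamic in practice
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 4096,     # same default the pipeline uses
    "temperature": 0.8,     # same default the pipeline uses
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as r:
    r.raise_for_status()
    for line in r.iter_lines():  # streamed chunks, as in pipe() above
        if line:
            print(line.decode("utf-8"))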