Update mlx_pipeline.py

Justin Hayes 2024-05-27 11:43:39 -04:00 committed by GitHub
parent 61802b0d95
commit d23c2c48e7

mlx_pipeline.py

@@ -1,112 +1,91 @@
 """
-Plugin Name: MLX Pipeline
-Description: A pipeline for running the mlx-lm server with a specified model and dynamically allocated port.
-Author: justinh-rahb
-License: MIT
-Python Dependencies: requests, subprocess, os, socket, schemas
+title: MLX Pipeline
+author: justinh-rahb
+date: 2024-05-27
+version: 1.1
+license: MIT
+description: A pipeline for generating text using Apple MLX Framework.
+dependencies: requests, mlx-lm, huggingface-hub
+environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN
 """
 
 from typing import List, Union, Generator, Iterator
-import requests
-import subprocess
-import os
-import socket
 from schemas import OpenAIChatMessage
+import requests
+import os
+import subprocess
+import logging
+from huggingface_hub import login
 
 
 class Pipeline:
     def __init__(self):
-        # Optionally, you can set the id and name of the pipeline.
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
-        self.process = None
-        self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2')  # Default model if not set in environment variable
-        self.port = self.find_free_port()
-        self.stop_sequences = os.getenv('MLX_STOP', '[INST]')  # Stop sequences from environment variable
-
-    @staticmethod
-    def find_free_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(('', 0))
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            return s.getsockname()[1]
+        self.host = os.getenv("MLX_HOST", "localhost")
+        self.port = os.getenv("MLX_PORT", "8080")
+        self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+        self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",")  # Default stop sequence is [INST]
+        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
+
+        if self.huggingface_token:
+            login(self.huggingface_token)
+
+        if self.subprocess:
+            self.start_mlx_server()
+
+    def start_mlx_server(self):
+        if not os.getenv("MLX_PORT"):
+            self.port = self.find_free_port()
+        command = f"mlx_lm.server --model {self.model} --port {self.port}"
+        self.server_process = subprocess.Popen(command, shell=True)
+        logging.info(f"Started MLX server on port {self.port}")
+
+    def find_free_port(self):
+        import socket
+
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(('', 0))
+        port = s.getsockname()[1]
+        s.close()
+        return port
 
     async def on_startup(self):
-        # This function is called when the server is started.
-        print(f"on_startup:{__name__}")
-        self.start_subprocess()
+        logging.info(f"on_startup:{__name__}")
 
     async def on_shutdown(self):
-        # This function is called when the server is stopped.
-        print(f"on_shutdown:{__name__}")
-        self.stop_subprocess()
-
-    def start_subprocess(self):
-        # Start the subprocess for "mlx_lm.server --model ${MLX_MODEL} --port ${PORT}"
-        try:
-            self.process = subprocess.Popen(
-                ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE
-            )
-            print(f"Subprocess started with PID: {self.process.pid} on port {self.port}")
-        except Exception as e:
-            print(f"Failed to start subprocess: {e}")
-
-    def stop_subprocess(self):
-        # Stop the subprocess if it is running
-        if self.process:
-            try:
-                self.process.terminate()
-                self.process.wait()
-                print(f"Subprocess with PID {self.process.pid} terminated")
-            except Exception as e:
-                print(f"Failed to terminate subprocess: {e}")
+        if self.subprocess and hasattr(self, 'server_process'):
+            self.server_process.terminate()
+            logging.info(f"Terminated MLX server on port {self.port}")
 
     def get_response(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
     ) -> Union[str, Generator, Iterator]:
-        # This is where you can add your custom pipelines like RAG.
-        print(f"get_response:{__name__}")
+        logging.info(f"get_response:{__name__}")
 
-        MLX_BASE_URL = f"http://localhost:{self.port}"
-        MODEL = self.model
+        url = f"http://{self.host}:{self.port}/v1/chat/completions"
+        headers = {"Content-Type": "application/json"}
 
-        # Convert OpenAIChatMessage objects to dictionaries
-        messages_dict = [{"role": message.role, "content": message.content} for message in messages]
-
-        # Extract additional parameters from the body
-        temperature = body.get("temperature", 1.0)
-        max_tokens = body.get("max_tokens", 100)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        stop = self.stop_sequences
-
-        if "user" in body:
-            print("######################################")
-            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
-            print(f"# Message: {user_message}")
-            print("######################################")
+        # Extract parameters from the request body
+        max_tokens = body.get("max_tokens", 1024)
+        temperature = body.get("temperature", 0.8)
+        repeat_penalty = body.get("repeat_penalty", 1.0)
 
         payload = {
-            "model": MODEL,
-            "messages": messages_dict,
-            "temperature": temperature,
+            "messages": messages,
             "max_tokens": max_tokens,
-            "top_p": top_p,
-            "repetition_penalty": repetition_penalty,
-            "stop": stop
+            "temperature": temperature,
+            "repetition_penalty": repeat_penalty,
+            "stop": self.stop_sequence,
+            "stream": body.get("stream", False)
         }
 
         try:
-            r = requests.post(
-                url=f"{MLX_BASE_URL}/v1/chat/completions",
-                json=payload,
-                stream=True,
-            )
+            r = requests.post(url, headers=headers, json=payload, stream=body.get("stream", False))
             r.raise_for_status()
 
-            return r.iter_lines()
+            if body.get("stream", False):
+                return r.iter_lines()
+            else:
+                return r.json()
         except Exception as e:
             return f"Error: {e}"