Merge pull request #16 from justinh-rahb/mlx-lm

Fix MLX Pipeline
This commit is contained in:
Timothy Jaeryang Baek 2024-05-27 10:53:28 -07:00 committed by GitHub
commit c0d4b828a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,118 +1,99 @@
"""
title: MLX Pipeline
author: justinh-rahb
date: 2024-05-22
version: 1.0
date: 2024-05-27
version: 1.1
license: MIT
description: A pipeline for running the mlx-lm server with a specified model.
dependencies: requests, mlx-lm
environment_variables: MLX_MODEL
description: A pipeline for generating text using Apple MLX Framework.
dependencies: requests, mlx-lm, huggingface-hub
environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN
"""
from typing import List, Union, Generator, Iterator
import requests
import subprocess
import os
import socket
from schemas import OpenAIChatMessage
import requests
import os
import subprocess
import logging
from huggingface_hub import login
# Pipeline that forwards chat requests to an mlx_lm server (see get_response).
class Pipeline:
def __init__(self):
# NOTE(review): this span appears to interleave two revisions of __init__
# (rendered diff without +/- markers, indentation stripped). Read in order,
# self.port and self.model are each assigned twice. Confirm against the
# repository before relying on statement order.
# Optionally, you can set the id and name of the pipeline.
self.id = "mlx_pipeline"
self.name = "MLX Pipeline"
# Handle of the server child process; assigned by start_subprocess().
self.process = None
self.model = os.getenv(
"MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
) # Default model if not set in environment variable
# Assigned again from MLX_PORT a few lines below.
self.port = self.find_free_port()
# Raw MLX_STOP string (not split); get_response reads it into `stop`.
self.stop_sequences = os.getenv(
"MLX_STOP", "[INST]"
) # Stop sequences from environment variable
# Host and port used to build the chat-completions URL in get_response.
self.host = os.getenv("MLX_HOST", "localhost")
self.port = os.getenv("MLX_PORT", "8080")
self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
# MLX_STOP split on commas into a list of stop strings.
self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",") # Default stop sequence is [INST]
# True only when MLX_SUBPROCESS equals "true" (case-insensitive); defaults to "true".
self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
# Optional token; None when HUGGINGFACE_TOKEN is unset.
self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
@staticmethod
def find_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
# NOTE(review): these statements follow a `return` and reference `self`; they
# look like the tail of __init__ from the newer revision. Confirm placement.
# Log in to Hugging Face only when a token was configured.
if self.huggingface_token:
login(self.huggingface_token)
# Launch the local MLX server when the MLX_SUBPROCESS flag is enabled.
if self.subprocess:
self.start_mlx_server()
def start_mlx_server(self):
    """Launch ``mlx_lm.server`` for ``self.model`` as a child process.

    When the ``MLX_PORT`` environment variable is unset or empty, a port is
    obtained from ``self.find_free_port()`` and stored in ``self.port``
    before launching. On success the process handle is stored in
    ``self.server_process``. If the executable cannot be started, the error
    is logged and ``self.server_process`` is left unset.
    """
    if not os.getenv("MLX_PORT"):
        self.port = self.find_free_port()

    # An argument list (no shell) keeps spaces and shell metacharacters in
    # MLX_MODEL from being interpreted, and makes the Popen handle refer to
    # the server itself rather than to an intermediate shell, so a later
    # terminate() reaches the right process.
    command = ["mlx_lm.server", "--model", self.model, "--port", str(self.port)]
    try:
        self.server_process = subprocess.Popen(command)
    except OSError:
        # Best effort: a missing or non-executable binary must not abort
        # pipeline construction.
        logging.exception("Failed to start MLX server on port %s", self.port)
        return
    logging.info("Started MLX server on port %s", self.port)
def find_free_port(self):
    """Return a TCP port number that the OS reported as free.

    A temporary IPv4 stream socket is bound to port 0 so the OS assigns the
    port; the number is read back and the socket is closed.

    Returns:
        int: The port number assigned to the temporary socket.

    Note:
        The socket is closed before returning, so nothing reserves the port
        for the caller afterwards.
    """
    import socket

    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]
async def on_startup(self):
# This function is called when the server is started.
# NOTE(review): a print/start_subprocess() pair and a logging call are both
# present; presumably two revisions of this hook interleaved. Confirm which
# one is current.
print(f"on_startup:{__name__}")
# Launches the mlx_lm.server child process (handle kept in self.process).
self.start_subprocess()
logging.info(f"on_startup:{__name__}")
async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
# Terminates the child process held in self.process, if there is one.
self.stop_subprocess()
def start_subprocess(self):
    """Start ``mlx_lm.server --model <self.model> --port <self.port>``.

    On success the process handle is stored in ``self.process`` and the PID
    is printed. A failure to launch is printed and leaves ``self.process``
    unchanged; it is never raised to the caller.
    """
    command = ["mlx_lm.server", "--model", self.model, "--port", str(self.port)]
    try:
        # The child inherits this process's stdout/stderr. Piping them
        # without ever reading would stall the server once the OS pipe
        # buffer fills up.
        self.process = subprocess.Popen(command)
    except Exception as e:
        # Deliberately broad: starting the server is best effort.
        print(f"Failed to start subprocess: {e}")
    else:
        print(
            f"Subprocess started with PID: {self.process.pid} on port {self.port}"
        )
def stop_subprocess(self):
    """Terminate the process held in ``self.process``, if there is one.

    Sends a terminate request and waits up to ten seconds for the process to
    exit; a process still running after that is killed. Errors are printed
    rather than raised so that shutdown can continue.
    """
    # Stop the subprocess if it is running
    if not self.process:
        return
    try:
        self.process.terminate()
        try:
            self.process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # Do not let an unresponsive child block shutdown forever.
            self.process.kill()
            self.process.wait()
        print(f"Subprocess with PID {self.process.pid} terminated")
    except Exception as e:
        print(f"Failed to terminate subprocess: {e}")
# NOTE(review): these statements appear detached from their method; they look
# like the on_shutdown body of the newer revision. Confirm placement.
# Only terminate when the subprocess flag is on and a server handle was stored.
if self.subprocess and hasattr(self, 'server_process'):
self.server_process.terminate()
logging.info(f"Terminated MLX server on port {self.port}")
def get_response(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# Posts the chat messages to the server's /v1/chat/completions endpoint and
# returns the reply; an exception raised inside the try block is returned as
# an "Error: ..." string instead of propagating.
# NOTE(review): this body appears to interleave two revisions (rendered diff
# without +/- markers, indentation stripped): names are assigned twice, the
# payload repeats keys and two POST calls are present. Confirm against the
# repository before changing behaviour.
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
logging.info(f"get_response:{__name__}")
# Two ways of addressing the server: fixed localhost vs. configurable self.host.
MLX_BASE_URL = f"http://localhost:{self.port}"
MODEL = self.model
url = f"http://{self.host}:{self.port}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
# Extract and validate parameters from the request body
max_tokens = body.get("max_tokens", 1024)
if not isinstance(max_tokens, int) or max_tokens < 0:
max_tokens = 1024 # Default to 1024 if invalid
# Extract additional parameters from the body
temperature = body.get("temperature", 0.8)
# Read in order, this reassignment discards the validated max_tokens above.
max_tokens = body.get("max_tokens", 1000)
top_p = body.get("top_p", 1.0)
repetition_penalty = body.get("repetition_penalty", 1.0)
# Unsplit MLX_STOP string set in __init__ (self.stop_sequence is the split list).
stop = self.stop_sequences
if not isinstance(temperature, (int, float)) or temperature < 0:
temperature = 0.8 # Default to 0.8 if invalid
# Prints the requesting user's name/id and the message when the body carries a "user" entry.
if "user" in body:
print("######################################")
print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
print(f"# Message: {user_message}")
print("######################################")
# Note the body key is "repeat_penalty" while the payload key is "repetition_penalty".
repeat_penalty = body.get("repeat_penalty", 1.0)
if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0:
repeat_penalty = 1.0 # Default to 1.0 if invalid
# In a dict literal the last occurrence of a repeated key wins, so as written
# "messages", "temperature", "repetition_penalty" and "stop" take their later values.
payload = {
"model": MODEL,
# Expects message objects with model_dump(); the signature declares List[dict].
"messages": [message.model_dump() for message in messages],
"temperature": temperature,
"messages": messages,
"max_tokens": max_tokens,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"stop": stop,
"temperature": temperature,
"repetition_penalty": repeat_penalty,
"stop": self.stop_sequence,
"stream": body.get("stream", False)
}
try:
# First POST: always streamed, addressed via MLX_BASE_URL.
r = requests.post(
url=f"{MLX_BASE_URL}/v1/chat/completions",
json=payload,
stream=True,
)
# Second POST: streams only when the request body asks for it.
r = requests.post(url, headers=headers, json=payload, stream=body.get("stream", False))
r.raise_for_status()
# NOTE(review): read in order, this return makes the if/else below unreachable.
return r.iter_lines()
# Streamed replies are returned as a line iterator, others as parsed JSON.
if body.get("stream", False):
return r.iter_lines()
else:
return r.json()
except Exception as e:
return f"Error: {e}"