Merge pull request #16 from justinh-rahb/mlx-lm

Fix MLX Pipeline
Timothy Jaeryang Baek 2024-05-27 10:53:28 -07:00 committed by GitHub
commit c0d4b828a0


@@ -1,118 +1,99 @@
 """
 title: MLX Pipeline
 author: justinh-rahb
-date: 2024-05-22
-version: 1.0
+date: 2024-05-27
+version: 1.1
 license: MIT
-description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm
-environment_variables: MLX_MODEL
+description: A pipeline for generating text using Apple MLX Framework.
+dependencies: requests, mlx-lm, huggingface-hub
+environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN
 """
 
 from typing import List, Union, Generator, Iterator
-import requests
-import subprocess
-import os
-import socket
 from schemas import OpenAIChatMessage
+import requests
+import os
+import subprocess
+import logging
+from huggingface_hub import login
 
 
 class Pipeline:
     def __init__(self):
-        # Optionally, you can set the id and name of the pipeline.
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
-        self.process = None
-        self.model = os.getenv(
-            "MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
-        )  # Default model if not set in environment variable
-        self.port = self.find_free_port()
-        self.stop_sequences = os.getenv(
-            "MLX_STOP", "[INST]"
-        )  # Stop sequences from environment variable
+        self.host = os.getenv("MLX_HOST", "localhost")
+        self.port = os.getenv("MLX_PORT", "8080")
+        self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+        self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",")  # Default stop sequence is [INST]
+        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
+        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
 
-    @staticmethod
-    def find_free_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            return s.getsockname()[1]
+        if self.huggingface_token:
+            login(self.huggingface_token)
+
+        if self.subprocess:
+            self.start_mlx_server()
+
+    def start_mlx_server(self):
+        if not os.getenv("MLX_PORT"):
+            self.port = self.find_free_port()
+        command = f"mlx_lm.server --model {self.model} --port {self.port}"
+        self.server_process = subprocess.Popen(command, shell=True)
+        logging.info(f"Started MLX server on port {self.port}")
+
+    def find_free_port(self):
+        import socket
+
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(('', 0))
+        port = s.getsockname()[1]
+        s.close()
+        return port
 
     async def on_startup(self):
-        # This function is called when the server is started.
-        print(f"on_startup:{__name__}")
-        self.start_subprocess()
+        logging.info(f"on_startup:{__name__}")
 
     async def on_shutdown(self):
-        # This function is called when the server is stopped.
-        print(f"on_shutdown:{__name__}")
-        self.stop_subprocess()
-
-    def start_subprocess(self):
-        # Start the subprocess for "mlx_lm.server --model ${MLX_MODEL} --port ${PORT}"
-        try:
-            self.process = subprocess.Popen(
-                ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
-            )
-            print(
-                f"Subprocess started with PID: {self.process.pid} on port {self.port}"
-            )
-        except Exception as e:
-            print(f"Failed to start subprocess: {e}")
-
-    def stop_subprocess(self):
-        # Stop the subprocess if it is running
-        if self.process:
-            try:
-                self.process.terminate()
-                self.process.wait()
-                print(f"Subprocess with PID {self.process.pid} terminated")
-            except Exception as e:
-                print(f"Failed to terminate subprocess: {e}")
+        if self.subprocess and hasattr(self, 'server_process'):
+            self.server_process.terminate()
+            logging.info(f"Terminated MLX server on port {self.port}")
 
     def get_response(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
     ) -> Union[str, Generator, Iterator]:
-        # This is where you can add your custom pipelines like RAG.'
-        print(f"get_response:{__name__}")
+        logging.info(f"get_response:{__name__}")
 
-        MLX_BASE_URL = f"http://localhost:{self.port}"
-        MODEL = self.model
+        url = f"http://{self.host}:{self.port}/v1/chat/completions"
+        headers = {"Content-Type": "application/json"}
 
-        # Extract additional parameters from the body
-        temperature = body.get("temperature", 0.8)
-        max_tokens = body.get("max_tokens", 1000)
-        top_p = body.get("top_p", 1.0)
-        repetition_penalty = body.get("repetition_penalty", 1.0)
-        stop = self.stop_sequences
+        # Extract and validate parameters from the request body
+        max_tokens = body.get("max_tokens", 1024)
+        if not isinstance(max_tokens, int) or max_tokens < 0:
+            max_tokens = 1024  # Default to 1024 if invalid
 
-        if "user" in body:
-            print("######################################")
-            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
-            print(f"# Message: {user_message}")
-            print("######################################")
+        temperature = body.get("temperature", 0.8)
+        if not isinstance(temperature, (int, float)) or temperature < 0:
+            temperature = 0.8  # Default to 0.8 if invalid
+
+        repeat_penalty = body.get("repeat_penalty", 1.0)
+        if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0:
+            repeat_penalty = 1.0  # Default to 1.0 if invalid
 
         payload = {
-            "model": MODEL,
-            "messages": [message.model_dump() for message in messages],
-            "temperature": temperature,
+            "messages": messages,
             "max_tokens": max_tokens,
-            "top_p": top_p,
-            "repetition_penalty": repetition_penalty,
-            "stop": stop,
+            "temperature": temperature,
+            "repetition_penalty": repeat_penalty,
+            "stop": self.stop_sequence,
+            "stream": body.get("stream", False)
         }
 
         try:
-            r = requests.post(
-                url=f"{MLX_BASE_URL}/v1/chat/completions",
-                json=payload,
-                stream=True,
-            )
+            r = requests.post(url, headers=headers, json=payload, stream=body.get("stream", False))
             r.raise_for_status()
 
-            return r.iter_lines()
+            if body.get("stream", False):
+                return r.iter_lines()
+            else:
+                return r.json()
         except Exception as e:
             return f"Error: {e}"
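For reference, a minimal sketch of how the updated pipeline might be exercised directly. All concrete values here (the environment variable settings, the module name mlx_pipeline, and the sample messages) are illustrative assumptions, not part of this commit; it also assumes mlx-lm is installed, an mlx_lm.server instance is reachable at the configured host and port, and the pipelines runtime dependencies (e.g. the schemas module) are importable.

import os

# Illustrative configuration -- these values are assumptions for the example.
# They must be set before Pipeline() is constructed, since __init__ reads them.
os.environ["MLX_HOST"] = "localhost"
os.environ["MLX_PORT"] = "8080"
os.environ["MLX_MODEL"] = "mistralai/Mistral-7B-Instruct-v0.2"
os.environ["MLX_SUBPROCESS"] = "false"  # assume mlx_lm.server is already running

from mlx_pipeline import Pipeline  # hypothetical module name for the file above

pipeline = Pipeline()
messages = [{"role": "user", "content": "Hello!"}]

# Non-streaming: get_response() returns the parsed JSON response body.
result = pipeline.get_response(
    user_message="Hello!",
    model_id="mlx_pipeline",
    messages=messages,
    body={"stream": False, "max_tokens": 256, "temperature": 0.7},
)
print(result)

# Streaming: get_response() returns r.iter_lines(), an iterator of raw
# server-sent-event lines that the caller is expected to parse.
for line in pipeline.get_response(
    user_message="Hello!",
    model_id="mlx_pipeline",
    messages=messages,
    body={"stream": True},
):
    if line:
        print(line.decode("utf-8"))

Note that the branching on body.get("stream", False) is the substance of the fix: the previous version always requested a stream and always returned r.iter_lines(), so non-streaming clients received raw SSE lines instead of a parsed completion.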