Justin Hayes 2024-05-22 15:10:51 -04:00
parent 7268268f0f
commit cae961c2d4


@@ -5,15 +5,17 @@ date: 2024-05-22
 version: 1.0
 license: MIT
 description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm
-environment_variables: MLX_MODEL
+dependencies: requests, mlx-lm, huggingface_hub
+environment_variables: MLX_MODEL, MLX_STOP, HUGGINGFACE_TOKEN
 """
 
 from typing import List, Union, Generator, Iterator
-import requests
 import subprocess
 import os
 import socket
+import time
+import requests
+from huggingface_hub import login
 
 from schemas import OpenAIChatMessage
 
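Note on this hunk: the pipeline is now configured entirely through environment variables. A minimal sketch of how the three variables might be set before the pipeline is loaded (the values are illustrative, and the token is a hypothetical placeholder):

    import os

    # Variable names come from the frontmatter above; values are examples only.
    os.environ["MLX_MODEL"] = "mistralai/Mistral-7B-Instruct-v0.2"  # default used below
    os.environ["MLX_STOP"] = "[INST]"
    os.environ["HUGGINGFACE_TOKEN"] = "hf_..."  # hypothetical placeholder; never commit a real token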
@@ -23,21 +25,30 @@ class Pipeline:
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
         self.process = None
-        self.model = os.getenv(
-            "MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
-        )  # Default model if not set in environment variable
+        self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2')  # Default model if not set in environment variable
         self.port = self.find_free_port()
-        self.stop_sequences = os.getenv(
-            "MLX_STOP", "[INST]"
-        )  # Stop sequences from environment variable
+        self.stop_sequences = os.getenv('MLX_STOP', '[INST]')  # Stop sequences from environment variable
+        self.hf_token = os.getenv('HUGGINGFACE_TOKEN', None)  # Hugging Face token from environment variable
+
+        # Authenticate with Hugging Face if a token is provided
+        if self.hf_token:
+            self.authenticate_huggingface(self.hf_token)
 
     @staticmethod
     def find_free_port():
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
+            s.bind(('', 0))
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             return s.getsockname()[1]
 
+    @staticmethod
+    def authenticate_huggingface(token: str):
+        try:
+            login(token)
+            print("Successfully authenticated with Hugging Face.")
+        except Exception as e:
+            print(f"Failed to authenticate with Hugging Face: {e}")
+
     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
@@ -54,13 +65,18 @@ class Pipeline:
             self.process = subprocess.Popen(
                 ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
                 stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
+                stderr=subprocess.PIPE
             )
-            print(
-                f"Subprocess started with PID: {self.process.pid} on port {self.port}"
-            )
+            print(f"Subprocess started with PID: {self.process.pid} on port {self.port}")
+
+            # Check if the process has started correctly
+            time.sleep(2)  # Give it a moment to start
+            if self.process.poll() is not None:
+                raise RuntimeError(f"Subprocess failed to start. Return code: {self.process.returncode}")
+
         except Exception as e:
             print(f"Failed to start subprocess: {e}")
+            self.process = None
 
     def stop_subprocess(self):
         # Stop the subprocess if it is running
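The new startup check relies on Popen.poll(), which returns None while the child process is still running and the exit code once it has terminated. A small self-contained sketch of the same pattern (the sleep command and the 2-second grace period are illustrative):

    import subprocess
    import time

    proc = subprocess.Popen(["sleep", "5"])  # illustrative long-running child
    time.sleep(2)                            # grace period, as in the hunk above
    if proc.poll() is None:                  # None means the child is still alive
        print(f"child {proc.pid} is still running")
    else:
        raise RuntimeError(f"child exited early with return code {proc.returncode}")
    proc.terminate()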
@@ -71,6 +87,8 @@ class Pipeline:
                 print(f"Subprocess with PID {self.process.pid} terminated")
             except Exception as e:
                 print(f"Failed to terminate subprocess: {e}")
+            finally:
+                self.process = None
 
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
@@ -78,9 +96,15 @@ class Pipeline:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")
 
+        if not self.process or self.process.poll() is not None:
+            return "Error: Subprocess is not running."
+
         MLX_BASE_URL = f"http://localhost:{self.port}"
         MODEL = self.model
 
+        # Convert OpenAIChatMessage objects to dictionaries
+        messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+
         # Extract additional parameters from the body
         temperature = body.get("temperature", 0.8)
         max_tokens = body.get("max_tokens", 1000)
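Replacing message.model_dump() with an explicit comprehension (see the payload hunk below) limits the request body to just the role/content keys and no longer depends on the schema objects being Pydantic v2 models; the commit does not state its motivation, so that reading is an assumption. A sketch with a stand-in message type:

    from dataclasses import dataclass

    @dataclass
    class Msg:  # stand-in for OpenAIChatMessage, for illustration only
        role: str
        content: str

    messages = [Msg("system", "You are helpful."), Msg("user", "Hello!")]
    messages_dict = [{"role": m.role, "content": m.content} for m in messages]
    print(messages_dict)  # plain, JSON-serializable dicts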
@@ -96,12 +120,12 @@ class Pipeline:
 
         payload = {
             "model": MODEL,
-            "messages": [message.model_dump() for message in messages],
+            "messages": messages_dict,
             "temperature": temperature,
             "max_tokens": max_tokens,
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
-            "stop": stop,
+            "stop": stop
         }
 
         try:
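The diff is truncated at the try: that sends this payload. For orientation, a minimal sketch of how such a payload is typically posted to mlx_lm.server's OpenAI-compatible endpoint (the route and response shape follow the OpenAI chat-completions convention; the actual code after try: is outside this diff):

    import requests

    response = requests.post(
        f"{MLX_BASE_URL}/v1/chat/completions",  # OpenAI-style route served by mlx_lm.server
        json=payload,
        timeout=60,  # illustrative timeout, not in the original
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["message"]["content"])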