Justin Hayes 2024-05-22 15:10:51 -04:00
parent 7268268f0f
commit cae961c2d4


@@ -5,15 +5,17 @@ date: 2024-05-22
 version: 1.0
 license: MIT
 description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm
-environment_variables: MLX_MODEL
+dependencies: requests, mlx-lm, huggingface_hub
+environment_variables: MLX_MODEL, MLX_STOP, HUGGINGFACE_TOKEN
 """
 from typing import List, Union, Generator, Iterator
-import requests
 import subprocess
 import os
 import socket
 import time
+import requests
+from huggingface_hub import login
 from schemas import OpenAIChatMessage
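Note: the two new frontmatter entries map directly to environment variables read in __init__ below. A minimal sketch of setting them before the pipeline loads; the values here are placeholders, only the variable names come from the diff:

import os

os.environ["MLX_MODEL"] = "mistralai/Mistral-7B-Instruct-v0.2"  # model for mlx_lm.server to load
os.environ["MLX_STOP"] = "[INST]"  # stop sequence forwarded with each request
os.environ["HUGGINGFACE_TOKEN"] = "hf_xxx"  # placeholder; needed only for gated models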
@@ -23,21 +25,30 @@ class Pipeline:
         self.id = "mlx_pipeline"
         self.name = "MLX Pipeline"
         self.process = None
-        self.model = os.getenv(
-            "MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2"
-        )  # Default model if not set in environment variable
+        self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2')  # Default model if not set in environment variable
         self.port = self.find_free_port()
-        self.stop_sequences = os.getenv(
-            "MLX_STOP", "[INST]"
-        )  # Stop sequences from environment variable
+        self.stop_sequences = os.getenv('MLX_STOP', '[INST]')  # Stop sequences from environment variable
+        self.hf_token = os.getenv('HUGGINGFACE_TOKEN', None)  # Hugging Face token from environment variable
+
+        # Authenticate with Hugging Face if a token is provided
+        if self.hf_token:
+            self.authenticate_huggingface(self.hf_token)
 
     @staticmethod
     def find_free_port():
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
+            s.bind(('', 0))
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             return s.getsockname()[1]
 
+    @staticmethod
+    def authenticate_huggingface(token: str):
+        try:
+            login(token)
+            print("Successfully authenticated with Hugging Face.")
+        except Exception as e:
+            print(f"Failed to authenticate with Hugging Face: {e}")
+
     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
@@ -54,13 +65,18 @@ class Pipeline:
             self.process = subprocess.Popen(
                 ["mlx_lm.server", "--model", self.model, "--port", str(self.port)],
                 stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
-            )
-            print(
-                f"Subprocess started with PID: {self.process.pid} on port {self.port}"
+                stderr=subprocess.PIPE
             )
+            print(f"Subprocess started with PID: {self.process.pid} on port {self.port}")
+
+            # Check if the process has started correctly
+            time.sleep(2)  # Give it a moment to start
+            if self.process.poll() is not None:
+                raise RuntimeError(f"Subprocess failed to start. Return code: {self.process.returncode}")
         except Exception as e:
             print(f"Failed to start subprocess: {e}")
+            self.process = None
 
     def stop_subprocess(self):
         # Stop the subprocess if it is running
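The fixed time.sleep(2) above is a heuristic: large models can take longer than two seconds to load, and poll() only catches processes that have already exited. A hedged alternative sketch, not part of this commit, that polls the chosen port until the server accepts connections:

import socket
import time

def wait_for_port(port: int, host: str = "localhost", timeout: float = 30.0) -> bool:
    # Poll until the server accepts TCP connections or the timeout expires.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return True  # server is up and listening
        except OSError:
            time.sleep(0.25)  # not ready yet; back off briefly and retry
    return False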
@@ -71,6 +87,8 @@ class Pipeline:
                 print(f"Subprocess with PID {self.process.pid} terminated")
             except Exception as e:
                 print(f"Failed to terminate subprocess: {e}")
+            finally:
+                self.process = None
 
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
@@ -78,9 +96,15 @@ class Pipeline:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")
 
+        if not self.process or self.process.poll() is not None:
+            return "Error: Subprocess is not running."
+
         MLX_BASE_URL = f"http://localhost:{self.port}"
         MODEL = self.model
 
+        # Convert OpenAIChatMessage objects to dictionaries
+        messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+
         # Extract additional parameters from the body
         temperature = body.get("temperature", 0.8)
         max_tokens = body.get("max_tokens", 1000)
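The new list comprehension replaces the pydantic model_dump() call removed further down. Assuming OpenAIChatMessage is a pydantic model whose relevant fields are role and content (the stand-in class below is illustrative, not the real schemas module), both forms yield the same dictionaries:

from pydantic import BaseModel

class OpenAIChatMessage(BaseModel):  # hypothetical stand-in for schemas.OpenAIChatMessage
    role: str
    content: str

m = OpenAIChatMessage(role="user", content="hi")
assert {"role": m.role, "content": m.content} == m.model_dump()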
@@ -96,12 +120,12 @@ class Pipeline:
         payload = {
             "model": MODEL,
-            "messages": [message.model_dump() for message in messages],
+            "messages": messages_dict,
             "temperature": temperature,
             "max_tokens": max_tokens,
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
-            "stop": stop,
+            "stop": stop
         }
 
         try:
@@ -115,4 +139,4 @@ class Pipeline:
             return r.iter_lines()
         except Exception as e:
-            return f"Error: {e}"
+            return f"Error: {e}"