Fix: more robust MLX model switching

This commit is contained in:
Justin Hayes 2024-07-01 10:59:31 -04:00 committed by GitHub
parent 7743a40b41
commit 72e933cd6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,6 +9,7 @@ requirements: requests, mlx-lm, huggingface-hub, psutil
environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
""" """
import argparse
from typing import List, Union, Generator, Iterator from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage from schemas import OpenAIChatMessage
from pydantic import BaseModel from pydantic import BaseModel
@ -19,13 +20,16 @@ import logging
from huggingface_hub import login from huggingface_hub import login
import time import time
import psutil import psutil
import json
class Pipeline: class Pipeline:
class Valves(BaseModel): class Valves(BaseModel):
MLX_STOP: str = "[INST]" MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>"
HUGGINGFACE_TOKEN: str = "" HUGGINGFACE_TOKEN: str = ""
MLX_MODEL_PATTERN: str = "mistralai" MLX_MODEL_PATTERN: str = "meta-llama"
MLX_DEFAULT_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3" MLX_DEFAULT_MODEL: str = "meta-llama/Meta-Llama-3-8B-Instruct"
MLX_CHAT_TEMPLATE: str = ""
MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False
def __init__(self): def __init__(self):
self.type = "manifold" self.type = "manifold"
@ -101,9 +105,20 @@ class Pipeline:
if not os.getenv("MLX_PORT"): if not os.getenv("MLX_PORT"):
self.port = self.find_free_port() self.port = self.find_free_port()
command = f"mlx_lm.server --model {model_name} --port {self.port}"
logging.info(f"Starting MLX server with command: {command}") command = [
self.server_process = subprocess.Popen(command, shell=True) "mlx_lm.server",
"--model", model_name,
"--port", str(self.port),
]
if self.valves.MLX_CHAT_TEMPLATE:
command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE])
elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE:
command.append("--use-default-chat-template")
logging.info(f"Starting MLX server with command: {' '.join(command)}")
self.server_process = subprocess.Popen(command)
self.current_model = model_id self.current_model = model_id
logging.info(f"Started MLX server for model {model_name} on port {self.port}") logging.info(f"Started MLX server for model {model_name} on port {self.port}")
time.sleep(5) # Give the server some time to start up time.sleep(5) # Give the server some time to start up