diff --git a/examples/pipelines/providers/mlx_manifold_pipeline.py b/examples/pipelines/providers/mlx_manifold_pipeline.py
index 6c8aa16..51181a5 100644
--- a/examples/pipelines/providers/mlx_manifold_pipeline.py
+++ b/examples/pipelines/providers/mlx_manifold_pipeline.py
@@ -9,6 +9,7 @@ requirements: requests, mlx-lm, huggingface-hub, psutil
 environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
 """
 
+import argparse
 from typing import List, Union, Generator, Iterator
 from schemas import OpenAIChatMessage
 from pydantic import BaseModel
@@ -19,13 +20,16 @@ import logging
 from huggingface_hub import login
 import time
 import psutil
+import json
 
 class Pipeline:
     class Valves(BaseModel):
-        MLX_STOP: str = "[INST]"
+        MLX_STOP: str = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>"
         HUGGINGFACE_TOKEN: str = ""
-        MLX_MODEL_PATTERN: str = "mistralai"
-        MLX_DEFAULT_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3"
+        MLX_MODEL_PATTERN: str = "meta-llama"
+        MLX_DEFAULT_MODEL: str = "meta-llama/Meta-Llama-3-8B-Instruct"
+        MLX_CHAT_TEMPLATE: str = ""
+        MLX_USE_DEFAULT_CHAT_TEMPLATE: bool = False
 
     def __init__(self):
         self.type = "manifold"
@@ -101,9 +105,20 @@ class Pipeline:
 
         if not os.getenv("MLX_PORT"):
             self.port = self.find_free_port()
-        command = f"mlx_lm.server --model {model_name} --port {self.port}"
-        logging.info(f"Starting MLX server with command: {command}")
-        self.server_process = subprocess.Popen(command, shell=True)
+
+        command = [
+            "mlx_lm.server",
+            "--model", model_name,
+            "--port", str(self.port),
+        ]
+
+        if self.valves.MLX_CHAT_TEMPLATE:
+            command.extend(["--chat-template", self.valves.MLX_CHAT_TEMPLATE])
+        elif self.valves.MLX_USE_DEFAULT_CHAT_TEMPLATE:
+            command.append("--use-default-chat-template")
+
+        logging.info(f"Starting MLX server with command: {' '.join(command)}")
+        self.server_process = subprocess.Popen(command)
         self.current_model = model_id
         logging.info(f"Started MLX server for model {model_name} on port {self.port}")
         time.sleep(5)  # Give the server some time to start up
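
One behavioral note on the valve change above: MLX_STOP now packs three Llama 3 control tokens into a single comma-separated string, so whatever consumes the valve is expected to split it into a list before handing it to the server. Below is a minimal, hypothetical sketch of that handling against the OpenAI-compatible endpoint that mlx_lm.server exposes; the split, the host/port, and the request payload are illustrative assumptions, not code from this diff.

import requests

# Comma-separated stop valve, as defaulted in the diff above.
MLX_STOP = "<|start_header_id|>,<|end_header_id|>,<|eot_id|>"

# Split the valve into the list form an OpenAI-style "stop" field expects.
stop_tokens = MLX_STOP.split(",")

# Hypothetical request; localhost:8080 stands in for the pipeline's
# self.host and self.port, and the model name is the new valve default.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "stop": stop_tokens,
    },
)
print(response.json()["choices"][0]["message"]["content"])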