From 9d9712da98a08a88f9d8c8faf4409a14ba4bd46d Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 10:56:20 -0400 Subject: [PATCH 1/9] Add MLX-LM Server example --- pipelines/mlx_llm.py | 86 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 pipelines/mlx_llm.py diff --git a/pipelines/mlx_llm.py b/pipelines/mlx_llm.py new file mode 100644 index 0000000..553df1f --- /dev/null +++ b/pipelines/mlx_llm.py @@ -0,0 +1,86 @@ +from typing import List, Union, Generator, Iterator +import requests +import subprocess +import os +import socket +from schemas import OpenAIChatMessage + + +class Pipeline: + def __init__(self): + # Optionally, you can set the id and name of the pipeline. + self.id = "mlx_pipeline" + self.name = "MLX Pipeline" + self.process = None + self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2') # Default model if not set in environment variable + self.port = self.find_free_port() + + @staticmethod + def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + async def on_startup(self): + # This function is called when the server is started. + print(f"on_startup:{__name__}") + self.start_subprocess() + + async def on_shutdown(self): + # This function is called when the server is stopped. + print(f"on_shutdown:{__name__}") + self.stop_subprocess() + + def start_subprocess(self): + # Start the subprocess for "mlx_lm.server --model ${MLX_MODEL} --port ${PORT}" + try: + self.process = subprocess.Popen( + ["mlx_lm.server", "--model", self.model, "--port", str(self.port)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + print(f"Subprocess started with PID: {self.process.pid} on port {self.port}") + except Exception as e: + print(f"Failed to start subprocess: {e}") + + def stop_subprocess(self): + # Stop the subprocess if it is running + if self.process: + try: + self.process.terminate() + self.process.wait() + print(f"Subprocess with PID {self.process.pid} terminated") + except Exception as e: + print(f"Failed to terminate subprocess: {e}") + + def get_response( + self, user_message: str, messages: List[OpenAIChatMessage], body: dict + ) -> Union[str, Generator, Iterator]: + # This is where you can add your custom pipelines like RAG.' + print(f"get_response:{__name__}") + + MLX_BASE_URL = f"http://localhost:{self.port}" + MODEL = "llama3" + + if "user" in body: + print("######################################") + print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})') + print(f"# Message: {user_message}") + print("######################################") + + try: + r = requests.post( + url=f"{MLX_BASE_URL}/v1/chat/completions", + json={**body, "model": MODEL}, + stream=True, + ) + + r.raise_for_status() + + if body["stream"]: + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}" \ No newline at end of file From 9f4803a9b651bcd31682eccde6a2a0d7ee35764f Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 11:19:29 -0400 Subject: [PATCH 2/9] Move to examples --- pipelines/{mlx_llm.py => examples/mlx_pipeline.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pipelines/{mlx_llm.py => examples/mlx_pipeline.py} (100%) diff --git a/pipelines/mlx_llm.py b/pipelines/examples/mlx_pipeline.py similarity index 100% rename from pipelines/mlx_llm.py rename to pipelines/examples/mlx_pipeline.py From e263648f952896be72c85609bf672047e752b867 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 11:19:50 -0400 Subject: [PATCH 3/9] Add metadata header --- pipelines/examples/mlx_pipeline.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 553df1f..da7fea5 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -1,3 +1,12 @@ +""" +Name: MLX Pipeline +Description: A pipeline for running the mlx-lm server with a specified model. +Author: justinh-rahb +License: MIT +Python Dependencies: requests, subprocess, os, socket, schemas +Environment Variables: MLX_MODEL +""" + from typing import List, Union, Generator, Iterator import requests import subprocess From 830ae49f09e22c169e2295843b95ff58115a4ad9 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 11:20:43 -0400 Subject: [PATCH 4/9] Fix depends --- pipelines/examples/mlx_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index da7fea5..2e27145 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -3,7 +3,7 @@ Name: MLX Pipeline Description: A pipeline for running the mlx-lm server with a specified model. Author: justinh-rahb License: MIT -Python Dependencies: requests, subprocess, os, socket, schemas +Python Dependencies: requests, mlx-lm Environment Variables: MLX_MODEL """ From 940d91c21650a05906bbd33faa1ea1e2d724728b Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 11:59:36 -0400 Subject: [PATCH 5/9] Handle more params --- pipelines/examples/mlx_pipeline.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 2e27145..04e43be 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -23,6 +23,7 @@ class Pipeline: self.process = None self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2') # Default model if not set in environment variable self.port = self.find_free_port() + self.stop_sequences = os.getenv('MLX_STOP', None) # Stop sequences from environment variable @staticmethod def find_free_port(): @@ -70,7 +71,14 @@ class Pipeline: print(f"get_response:{__name__}") MLX_BASE_URL = f"http://localhost:{self.port}" - MODEL = "llama3" + MODEL = self.model + + # Extract additional parameters from the body + temperature = body.get("temperature", 1.0) + max_tokens = body.get("max_tokens", 100) + top_p = body.get("top_p", 1.0) + repetition_penalty = body.get("repetition_penalty", 1.0) + stop = self.stop_sequences if "user" in body: print("######################################") @@ -78,18 +86,26 @@ class Pipeline: print(f"# Message: {user_message}") print("######################################") + payload = { + "model": MODEL, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "stop": stop, + "stream": True # Always stream responses + } + try: r = requests.post( url=f"{MLX_BASE_URL}/v1/chat/completions", - json={**body, "model": MODEL}, + json=payload, stream=True, ) r.raise_for_status() - if body["stream"]: - return r.iter_lines() - else: - return r.json() + return r.iter_lines() except Exception as e: return f"Error: {e}" \ No newline at end of file From 46f4aa1ca538a5e215a5c21db8c3085f354a73e4 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 12:01:47 -0400 Subject: [PATCH 6/9] Adjust default params --- pipelines/examples/mlx_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 04e43be..9414e9b 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -23,7 +23,7 @@ class Pipeline: self.process = None self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2') # Default model if not set in environment variable self.port = self.find_free_port() - self.stop_sequences = os.getenv('MLX_STOP', None) # Stop sequences from environment variable + self.stop_sequences = os.getenv('MLX_STOP', '[INST]') # Stop sequences from environment variable @staticmethod def find_free_port(): @@ -74,8 +74,8 @@ class Pipeline: MODEL = self.model # Extract additional parameters from the body - temperature = body.get("temperature", 1.0) - max_tokens = body.get("max_tokens", 100) + temperature = body.get("temperature", 0.8) + max_tokens = body.get("max_tokens", 1000) top_p = body.get("top_p", 1.0) repetition_penalty = body.get("repetition_penalty", 1.0) stop = self.stop_sequences From 45cba3dc5ddc055c0bb3465a79c7b34ece8c8a63 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 12:07:08 -0400 Subject: [PATCH 7/9] Fix --- pipelines/examples/mlx_pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 9414e9b..86aa542 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -93,8 +93,7 @@ class Pipeline: "max_tokens": max_tokens, "top_p": top_p, "repetition_penalty": repetition_penalty, - "stop": stop, - "stream": True # Always stream responses + "stop": stop } try: From bd2ea926bba2700c933f6243d93ebbcbd26d0317 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 12:18:29 -0400 Subject: [PATCH 8/9] Refac meta --- pipelines/examples/mlx_pipeline.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 86aa542..f39d03a 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -1,10 +1,12 @@ """ -Name: MLX Pipeline -Description: A pipeline for running the mlx-lm server with a specified model. -Author: justinh-rahb -License: MIT -Python Dependencies: requests, mlx-lm -Environment Variables: MLX_MODEL +title: MLX Pipeline +author: justinh-rahb +date: 2024-05-22 +version: 1.0 +license: MIT +description: A pipeline for running the mlx-lm server with a specified model. +dependencies: requests, mlx-lm +environment: MLX_MODEL """ from typing import List, Union, Generator, Iterator From 7556800fdf08e9f6d983eb301d8ad586435d5673 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Wed, 22 May 2024 12:22:39 -0400 Subject: [PATCH 9/9] Refac meta --- pipelines/examples/mlx_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index f39d03a..68daaaa 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -6,7 +6,7 @@ version: 1.0 license: MIT description: A pipeline for running the mlx-lm server with a specified model. dependencies: requests, mlx-lm -environment: MLX_MODEL +environment_variables: MLX_MODEL """ from typing import List, Union, Generator, Iterator