diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py
index 2e27145..04e43be 100644
--- a/pipelines/examples/mlx_pipeline.py
+++ b/pipelines/examples/mlx_pipeline.py
@@ -23,6 +23,7 @@ class Pipeline:
         self.process = None
         self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2')  # Default model if not set in environment variable
         self.port = self.find_free_port()
+        self.stop_sequences = os.getenv('MLX_STOP', None)  # Stop sequences from environment variable
 
     @staticmethod
     def find_free_port():
@@ -70,7 +71,14 @@ class Pipeline:
         print(f"get_response:{__name__}")
 
         MLX_BASE_URL = f"http://localhost:{self.port}"
-        MODEL = "llama3"
+        MODEL = self.model
+
+        # Extract additional parameters from the body
+        temperature = body.get("temperature", 1.0)
+        max_tokens = body.get("max_tokens", 100)
+        top_p = body.get("top_p", 1.0)
+        repetition_penalty = body.get("repetition_penalty", 1.0)
+        stop = self.stop_sequences
 
         if "user" in body:
             print("######################################")
@@ -78,18 +86,26 @@ class Pipeline:
             print(f"# Message: {user_message}")
             print("######################################")
 
+        payload = {
+            "model": MODEL,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "top_p": top_p,
+            "repetition_penalty": repetition_penalty,
+            "stop": stop,
+            "stream": True  # Always stream responses
+        }
+
         try:
             r = requests.post(
                 url=f"{MLX_BASE_URL}/v1/chat/completions",
-                json={**body, "model": MODEL},
+                json=payload,
                 stream=True,
             )
 
             r.raise_for_status()
 
-            if body["stream"]:
-                return r.iter_lines()
-            else:
-                return r.json()
+            return r.iter_lines()
         except Exception as e:
             return f"Error: {e}"
\ No newline at end of file
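
Note on the new `MLX_STOP` variable: `os.getenv` returns a raw string (or `None`), and an OpenAI-compatible `/v1/chat/completions` endpoint accepts either a single string or a list of strings for `"stop"`, so a single sequence passes through unchanged. A minimal sketch of how a deployment could supply several stop sequences by splitting on commas; the splitting is an illustrative assumption, not part of this diff:

```python
import os

# Hypothetical pre-processing of MLX_STOP; the diff passes the raw
# string through as-is, which the API also accepts for a single sequence.
raw = os.getenv("MLX_STOP")  # e.g. "</s>,<|im_end|>"
stop_sequences = [s.strip() for s in raw.split(",")] if raw else None
```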
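
Because `get_response` now returns `r.iter_lines()` unconditionally, the previous non-streaming branch is gone: callers that send `"stream": false` will also receive raw server-sent-event lines rather than a parsed JSON body. A minimal consumption sketch, assuming `get_response` keeps its existing `(user_message, messages, body)` parameters from the surrounding file and that the MLX server emits OpenAI-style `data: {...}` chunks terminated by `data: [DONE]`:

```python
import json

pipeline = Pipeline()
body = {"temperature": 0.7, "max_tokens": 256}  # omitted knobs fall back to the defaults above
messages = [{"role": "user", "content": "Hello"}]

for line in pipeline.get_response("Hello", messages, body):
    if not line:
        continue  # iter_lines() yields empty keep-alive lines between SSE events
    text = line.decode("utf-8")
    if text.startswith("data: ") and text != "data: [DONE]":
        chunk = json.loads(text[len("data: "):])
        delta = chunk["choices"][0].get("delta", {}).get("content", "")
        print(delta, end="", flush=True)
```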