diff --git a/pipelines/examples/mlx_pipeline.py b/pipelines/examples/mlx_pipeline.py index 09b2ace..85b7b30 100644 --- a/pipelines/examples/mlx_pipeline.py +++ b/pipelines/examples/mlx_pipeline.py @@ -65,10 +65,18 @@ class Pipeline: url = f"http://{self.host}:{self.port}/v1/chat/completions" headers = {"Content-Type": "application/json"} - # Extract parameters from the request body + # Extract and validate parameters from the request body max_tokens = body.get("max_tokens", 1024) + if not isinstance(max_tokens, int) or max_tokens < 0: + max_tokens = 1024 # Default to 1024 if invalid + temperature = body.get("temperature", 0.8) + if not isinstance(temperature, (int, float)) or temperature < 0: + temperature = 0.8 # Default to 0.8 if invalid + repeat_penalty = body.get("repeat_penalty", 1.0) + if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0: + repeat_penalty = 1.0 # Default to 1.0 if invalid payload = { "messages": messages,