"""
title: MLX Pipeline
author: justinh-rahb
date: 2024-05-27
version: 1.1
license: MIT
description: A pipeline for generating text using Apple MLX Framework.
requirements: requests, mlx-lm, huggingface-hub
environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN
"""

from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
import requests
import os
import subprocess
import logging
from huggingface_hub import login


class Pipeline:
    """Open WebUI pipeline that proxies chat completions to an MLX server.

    Configuration comes from environment variables (MLX_HOST, MLX_PORT,
    MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN). When
    MLX_SUBPROCESS is "true" (the default) and MLX_PORT is not pinned,
    a local ``mlx_lm.server`` subprocess is started on a free port.
    """

    def __init__(self):
        # The id must be unique across all pipelines: an alphanumeric string
        # that may include underscores or hyphens — no spaces, special
        # characters, slashes, or backslashes.
        self.id = "mlx_pipeline"
        self.name = "MLX Pipeline"
        self.host = os.getenv("MLX_HOST", "localhost")
        self.port = os.getenv("MLX_PORT", "8080")
        self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
        # MLX_STOP may hold several comma-separated stop strings.
        # Default stop sequence is [INST].
        self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",")
        self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)

        if self.huggingface_token:
            login(self.huggingface_token)

        if self.subprocess:
            self.start_mlx_server()

    def start_mlx_server(self):
        """Launch a local ``mlx_lm.server`` subprocess on a free port.

        Intentionally a no-op when MLX_PORT is set in the environment: in
        that case an externally managed server is assumed to be listening
        there and we must not spawn a duplicate.
        """
        if os.getenv("MLX_PORT"):
            return
        self.port = self.find_free_port()
        # Pass the command as an argument list with shell=False so the
        # model name and port are never interpreted by a shell.
        command = [
            "mlx_lm.server",
            "--model",
            self.model,
            "--port",
            str(self.port),
        ]
        self.server_process = subprocess.Popen(command)
        logging.info("Started MLX server on port %s", self.port)

    def find_free_port(self):
        """Ask the OS for an unused TCP port and return its number.

        Note: the port is released before returning, so there is a small
        window in which another process could grab it.
        """
        import socket

        # Binding to port 0 makes the OS pick a free ephemeral port; the
        # context manager guarantees the socket is closed on every path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]

    async def on_startup(self):
        """Hook invoked by the host application when the pipeline starts."""
        logging.info("on_startup:%s", __name__)

    async def on_shutdown(self):
        """Terminate the managed MLX subprocess, if this instance started one."""
        if self.subprocess and hasattr(self, "server_process"):
            self.server_process.terminate()
            logging.info("Terminated MLX server on port %s", self.port)

    @staticmethod
    def _validated(value, default, allowed_types):
        """Return *value* if it is a non-negative instance of *allowed_types*;
        otherwise fall back to *default*.

        Mirrors the request-body sanitizing previously inlined in pipe():
        wrong type or a negative number both yield the default.
        """
        if isinstance(value, allowed_types) and value >= 0:
            return value
        return default

    def pipe(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
        """Forward a chat-completion request to the MLX server.

        Returns the parsed JSON response, a line iterator when streaming is
        requested, or an ``"Error: ..."`` string if the request fails.
        """
        logging.info("pipe:%s", __name__)

        url = f"http://{self.host}:{self.port}/v1/chat/completions"
        headers = {"Content-Type": "application/json"}
        stream = body.get("stream", False)

        payload = {
            "messages": messages,
            "max_tokens": self._validated(body.get("max_tokens", 4096), 4096, int),
            "temperature": self._validated(
                body.get("temperature", 0.8), 0.8, (int, float)
            ),
            # The MLX server expects the key "repetition_penalty", while the
            # incoming body uses "repeat_penalty".
            "repetition_penalty": self._validated(
                body.get("repeat_penalty", 1.0), 1.0, (int, float)
            ),
            "stop": self.stop_sequence,
            "stream": stream,
        }

        try:
            r = requests.post(url, headers=headers, json=payload, stream=stream)
            r.raise_for_status()
            return r.iter_lines() if stream else r.json()
        except Exception as e:
            # Surface the failure to the caller as a plain string rather than
            # raising, matching the host application's expectations.
            return f"Error: {e}"