Revert all recent changes

Justin Hayes 2024-05-22 16:03:10 -04:00
parent cae961c2d4
commit de765a1391


@@ -1,21 +1,16 @@
 """
-title: MLX Pipeline
-author: justinh-rahb
-date: 2024-05-22
-version: 1.0
-license: MIT
-description: A pipeline for running the mlx-lm server with a specified model.
-dependencies: requests, mlx-lm, huggingface_hub
-environment_variables: MLX_MODEL, MLX_STOP, HUGGINGFACE_TOKEN
+Plugin Name: MLX Pipeline
+Description: A pipeline for running the mlx-lm server with a specified model and dynamically allocated port.
+Author: justinh-rahb
+License: MIT
+Python Dependencies: requests, subprocess, os, socket, schemas
 """
 from typing import List, Union, Generator, Iterator
+import requests
 import subprocess
 import os
 import socket
-import time
-import requests
-from huggingface_hub import login
 from schemas import OpenAIChatMessage
@@ -28,11 +23,6 @@ class Pipeline:
         self.model = os.getenv('MLX_MODEL', 'mistralai/Mistral-7B-Instruct-v0.2')  # Default model if not set in environment variable
         self.port = self.find_free_port()
         self.stop_sequences = os.getenv('MLX_STOP', '[INST]')  # Stop sequences from environment variable
-        self.hf_token = os.getenv('HUGGINGFACE_TOKEN', None)  # Hugging Face token from environment variable
-
-        # Authenticate with Hugging Face if a token is provided
-        if self.hf_token:
-            self.authenticate_huggingface(self.hf_token)

     @staticmethod
     def find_free_port():
@@ -41,14 +31,6 @@ class Pipeline:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             return s.getsockname()[1]

-    @staticmethod
-    def authenticate_huggingface(token: str):
-        try:
-            login(token)
-            print("Successfully authenticated with Hugging Face.")
-        except Exception as e:
-            print(f"Failed to authenticate with Hugging Face: {e}")
-
     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
@@ -68,15 +50,8 @@ class Pipeline:
                 stderr=subprocess.PIPE
             )
             print(f"Subprocess started with PID: {self.process.pid} on port {self.port}")
-
-            # Check if the process has started correctly
-            time.sleep(2)  # Give it a moment to start
-            if self.process.poll() is not None:
-                raise RuntimeError(f"Subprocess failed to start. Return code: {self.process.returncode}")
-
         except Exception as e:
             print(f"Failed to start subprocess: {e}")
-            self.process = None

     def stop_subprocess(self):
         # Stop the subprocess if it is running
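The launch command itself sits above this hunk. A hypothetical reconstruction of the start sequence, assuming the pipeline shells out to mlx-lm's bundled OpenAI-compatible server (the actual Popen arguments are not visible in the diff):

import subprocess

model = "mistralai/Mistral-7B-Instruct-v0.2"  # default from MLX_MODEL above
port = 8080  # placeholder; the real code uses find_free_port()

# Assumed command: `python -m mlx_lm.server` starts mlx-lm's HTTP server.
process = subprocess.Popen(
    ["python", "-m", "mlx_lm.server", "--model", model, "--port", str(port)],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
print(f"Subprocess started with PID: {process.pid} on port {port}")

With the readiness check removed by this revert, a server that dies immediately after launch is no longer caught at startup; the failure surfaces only when a request is made.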
@@ -87,8 +62,6 @@ class Pipeline:
             print(f"Subprocess with PID {self.process.pid} terminated")
         except Exception as e:
             print(f"Failed to terminate subprocess: {e}")
-        finally:
-            self.process = None

     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
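The removed `finally` block cleared `self.process` even when termination failed. A minimal sketch of the pre-revert shutdown path; the `terminate()`/`wait()` calls are assumptions, since only the print/except/finally lines fall inside the hunk:

import subprocess
from typing import Optional

class PipelineSketch:
    process: Optional[subprocess.Popen] = None

    def stop_subprocess(self) -> None:
        if self.process is None:
            return
        try:
            self.process.terminate()      # assumed: ask the server to exit
            self.process.wait(timeout=5)  # assumed: reap it to avoid a zombie
            print(f"Subprocess with PID {self.process.pid} terminated")
        except Exception as e:
            print(f"Failed to terminate subprocess: {e}")
        finally:
            # The removed block: drop the handle even on failure, so the
            # pipeline doesn't keep pointing at a dead process object.
            self.process = None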
@@ -96,9 +69,6 @@ class Pipeline:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")

-        if not self.process or self.process.poll() is not None:
-            return "Error: Subprocess is not running."
-
         MLX_BASE_URL = f"http://localhost:{self.port}"
         MODEL = self.model
@@ -106,8 +76,8 @@ class Pipeline:
         messages_dict = [{"role": message.role, "content": message.content} for message in messages]

         # Extract additional parameters from the body
-        temperature = body.get("temperature", 0.8)
-        max_tokens = body.get("max_tokens", 1000)
+        temperature = body.get("temperature", 1.0)
+        max_tokens = body.get("max_tokens", 100)
         top_p = body.get("top_p", 1.0)
         repetition_penalty = body.get("repetition_penalty", 1.0)
         stop = self.stop_sequences
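Downstream of this hunk, these values are presumably forwarded to the mlx-lm server at MLX_BASE_URL. A hedged sketch of that request; the `/v1/chat/completions` path and payload shape are assumptions based on mlx_lm.server mimicking the OpenAI chat API:

import requests

MLX_BASE_URL = "http://localhost:8080"  # placeholder; real code interpolates self.port

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 1.0,   # post-revert default
    "max_tokens": 100,    # post-revert default
    "top_p": 1.0,
    "repetition_penalty": 1.0,
    "stop": "[INST]",
}
r = requests.post(f"{MLX_BASE_URL}/v1/chat/completions", json=payload, timeout=60)
print(r.json()["choices"][0]["message"]["content"])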