Merge branch 'open-webui:main' into mlx-manifold

This commit is contained in:
Justin Hayes 2024-07-01 10:55:07 -04:00 committed by GitHub
commit 7743a40b41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,55 +2,55 @@
title: MLX Pipeline title: MLX Pipeline
author: justinh-rahb author: justinh-rahb
date: 2024-05-27 date: 2024-05-27
version: 1.1 version: 1.2
license: MIT license: MIT
description: A pipeline for generating text using Apple MLX Framework. description: A pipeline for generating text using Apple MLX Framework.
requirements: requests, mlx-lm, huggingface-hub requirements: requests, mlx-lm, huggingface-hub
environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS
""" """
from typing import List, Union, Generator, Iterator from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage from schemas import OpenAIChatMessage
from pydantic import BaseModel
import requests import requests
import os import os
import subprocess import subprocess
import logging import logging
from huggingface_hub import login from huggingface_hub import login
class Pipeline: class Pipeline:
# Settings model for the pipeline; __init__ stores an instance on self.valves.
# Documented with comments rather than a class docstring so the model's
# generated schema is left untouched.
class Valves(BaseModel):
    # Model identifier passed to `mlx_lm.server --model` when the server starts.
    MLX_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3"
    # Stop sequence(s) as one comma-separated string; update_valves() splits it on ",".
    MLX_STOP: str = "[INST]"
    # Hugging Face token; update_valves() calls login() only when this is non-empty.
    HUGGINGFACE_TOKEN: str = ""
def __init__(self): def __init__(self):
# Optionally, you can set the id and name of the pipeline. self.id = "mlx_pipeline"
# Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline.
# The identifier must be unique across all pipelines.
# The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
# self.id = "mlx_pipeline"
self.name = "MLX Pipeline" self.name = "MLX Pipeline"
self.valves = self.Valves()
self.update_valves()
self.host = os.getenv("MLX_HOST", "localhost") self.host = os.getenv("MLX_HOST", "localhost")
self.port = os.getenv("MLX_PORT", "8080") self.port = os.getenv("MLX_PORT", "8080")
self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(
","
) # Default stop sequence is [INST]
self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true" self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
if self.huggingface_token:
login(self.huggingface_token)
if self.subprocess: if self.subprocess:
self.start_mlx_server() self.start_mlx_server()
def update_valves(self):
    """Apply the current valve values to the pipeline's runtime state.

    Logs in to the Hugging Face Hub when the ``HUGGINGFACE_TOKEN`` valve is
    non-empty, then rebuilds ``self.stop_sequence`` from the comma-separated
    ``MLX_STOP`` valve.
    """
    token = self.valves.HUGGINGFACE_TOKEN
    if token:
        login(token)

    # Drop empty segments: "".split(",") yields [""] and a trailing comma
    # yields a trailing "", neither of which is a usable stop string.
    self.stop_sequence = [
        stop for stop in self.valves.MLX_STOP.split(",") if stop
    ]
def start_mlx_server(self):
    """Launch ``mlx_lm.server`` as a child process for the configured model.

    When the ``MLX_PORT`` environment variable is unset, a port is obtained
    from ``find_free_port()`` and stored on ``self.port``. The process handle
    is stored on ``self.server_process``.
    """
    if not os.getenv("MLX_PORT"):
        self.port = self.find_free_port()

    # Pass an argument list and no shell: the model name comes from the
    # valves configuration, so interpolating it into a shell string would
    # let shell metacharacters in it run arbitrary commands. Without the
    # intermediate shell, terminate() also signals the server process itself.
    command = [
        "mlx_lm.server",
        "--model",
        str(self.valves.MLX_MODEL),
        "--port",
        str(self.port),
    ]
    self.server_process = subprocess.Popen(command)
    logging.info(f"Started MLX server on port {self.port}")
def find_free_port(self): def find_free_port(self):
import socket import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0)) s.bind(("", 0))
port = s.getsockname()[1] port = s.getsockname()[1]
@ -65,6 +65,13 @@ class Pipeline:
self.server_process.terminate() self.server_process.terminate()
logging.info(f"Terminated MLX server on port {self.port}") logging.info(f"Terminated MLX server on port {self.port}")
async def on_valves_updated(self):
    """Re-apply valve settings and restart the managed MLX server, if any.

    Always calls ``update_valves()``. When the pipeline runs the server as a
    subprocess and one has been started, the old process is terminated and
    waited for before ``start_mlx_server()`` launches a replacement.
    """
    import asyncio

    self.update_valves()
    if not (self.subprocess and hasattr(self, "server_process")):
        return

    old_server = self.server_process
    old_server.terminate()

    # terminate() only sends the signal. Wait for the process to exit so it is
    # reaped and has released its port before the replacement tries to bind.
    # The wait runs in the default executor so the event loop is not blocked.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, old_server.wait, 10)
    except subprocess.TimeoutExpired:
        # The server ignored the terminate request; force it down.
        old_server.kill()
        await loop.run_in_executor(None, old_server.wait)

    logging.info(f"Terminated MLX server on port {self.port}")
    self.start_mlx_server()
def pipe( def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]: ) -> Union[str, Generator, Iterator]:
@ -73,18 +80,17 @@ class Pipeline:
url = f"http://{self.host}:{self.port}/v1/chat/completions" url = f"http://{self.host}:{self.port}/v1/chat/completions"
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
# Extract and validate parameters from the request body
max_tokens = body.get("max_tokens", 4096) max_tokens = body.get("max_tokens", 4096)
if not isinstance(max_tokens, int) or max_tokens < 0: if not isinstance(max_tokens, int) or max_tokens < 0:
max_tokens = 4096 # Default to 4096 if invalid max_tokens = 4096
temperature = body.get("temperature", 0.8) temperature = body.get("temperature", 0.8)
if not isinstance(temperature, (int, float)) or temperature < 0: if not isinstance(temperature, (int, float)) or temperature < 0:
temperature = 0.8 # Default to 0.8 if invalid temperature = 0.8
repeat_penalty = body.get("repeat_penalty", 1.0) repeat_penalty = body.get("repeat_penalty", 1.0)
if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0: if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0:
repeat_penalty = 1.0 # Default to 1.0 if invalid repeat_penalty = 1.0
payload = { payload = {
"messages": messages, "messages": messages,
@ -106,4 +112,4 @@ class Pipeline:
else: else:
return r.json() return r.json()
except Exception as e: except Exception as e:
return f"Error: {e}" return f"Error: {e}"