feat: LlamaCpp class

Timothy Jaeryang Baek 2025-01-28 13:27:46 -08:00
parent 2b057ff73e
commit 3d15266ea6


@@ -10,6 +10,231 @@ import time
import socket
class LlamaCpp:
def __init__(
self, models_dir, cache_dir="./cache", verbose=False, timeout_minutes=5
):
"""
Initialize the LlamaCpp class.
Args:
models_dir (str): Directory where GGUF models are stored.
cache_dir (str): Directory to store llama.cpp binaries and related assets.
verbose (bool): Whether to enable verbose logging.
timeout_minutes (int): Timeout for shutting down idle servers.
"""
self.models_dir = models_dir
self.cache_dir = cache_dir
self.verbose = verbose
self.timeout_minutes = timeout_minutes
self.llama_cpp_path = (
self._install_llama_cpp_binaries()
) # Handle binaries installation
self.servers = (
{}
) # Maintain a mapping of model names to LlamaCppServer instances
def list_models(self):
"""
List all GGUF models available in the `models_dir`.
Returns:
list: A list of model names (files ending in ".gguf").
"""
if not os.path.exists(self.models_dir):
self._log(f"Models directory does not exist: {self.models_dir}")
return []
models = [f for f in os.listdir(self.models_dir) if f.endswith(".gguf")]
self._log(f"Available models: {models}")
return models
def chat_completion(self, body):
"""
Handle chat completion requests.
Args:
body (dict): The payload for the chat completion request. It must contain the "model" key.
Returns:
dict or generator: Response from the server (non-streaming or streaming mode).
"""
if "model" not in body:
raise ValueError("The request body must contain a 'model' key.")
model_name = body["model"]
gguf_path = os.path.join(self.models_dir, model_name)
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"Model file not found: {gguf_path}")
# Check if the server for this model is already running
if model_name not in self.servers or not self.servers[model_name]._server_url:
self._log(f"Initializing a new server for model: {model_name}")
self.servers[model_name] = self._create_server(gguf_path)
server = self.servers[model_name]
return server.chat_completion(body)
def _create_server(self, gguf_path):
"""
Create a new LlamaCppServer instance for the given model.
Args:
gguf_path (str): Path to the GGUF model file.
Returns:
LlamaCppServer: A new server instance.
"""
return LlamaCppServer(
llama_cpp_path=self.llama_cpp_path,
gguf_path=gguf_path,
cache_dir=self.cache_dir,
verbose=self.verbose,
timeout_minutes=self.timeout_minutes,
)
def _install_llama_cpp_binaries(self):
"""
Download and install llama.cpp binaries.
Returns:
str: Path to the installed llama.cpp binaries.
"""
self._log("Installing llama.cpp binaries...")
release_info = self._get_latest_release()
assets = release_info["assets"]
asset = self._get_appropriate_asset(assets)
if not asset:
raise RuntimeError("No appropriate binary found for your system.")
asset_name = asset["name"]
if self._check_cache(release_info, asset):
self._log("Using cached llama.cpp binaries.")
else:
self._download_and_unzip(asset["browser_download_url"], asset_name)
self._update_cache_info(release_info, asset)
return os.path.join(self.cache_dir, "llama_cpp")
def _get_latest_release(self):
"""
Fetch the latest release of llama.cpp from GitHub.
Returns:
dict: Release information.
"""
api_url = "https://api.github.com/repos/ggerganov/llama.cpp/releases/latest"
response = requests.get(api_url)
if response.status_code == 200:
return response.json()
else:
raise RuntimeError(
f"Failed to fetch release info. Status code: {response.status_code}"
)
def _get_appropriate_asset(self, assets):
"""
Select the appropriate binary asset for the current system.
Args:
assets (list): List of asset metadata from the release.
Returns:
dict or None: Matching asset metadata, or None if no match found.
"""
system = platform.system().lower()
machine = platform.machine().lower()
processor = platform.processor()
if system == "windows":
if "arm" in machine:
return next((a for a in assets if "win-arm64" in a["name"]), None)
elif "avx512" in processor:
return next((a for a in assets if "win-avx512-x64" in a["name"]), None)
elif "avx2" in processor:
return next((a for a in assets if "win-avx2-x64" in a["name"]), None)
elif "avx" in processor:
return next((a for a in assets if "win-avx-x64" in a["name"]), None)
else:
return next((a for a in assets if "win-noavx-x64" in a["name"]), None)
elif system == "darwin":
if "arm" in machine:
return next((a for a in assets if "macos-arm64" in a["name"]), None)
else:
return next((a for a in assets if "macos-x64" in a["name"]), None)
elif system == "linux":
return next((a for a in assets if "ubuntu-x64" in a["name"]), None)
return None
def _check_cache(self, release_info, asset):
"""
Check whether the latest binaries are already cached.
Args:
release_info (dict): Metadata of the latest release.
asset (dict): Metadata of the selected asset.
Returns:
bool: True if the cached binary matches the latest release, False otherwise.
"""
cache_info_path = os.path.join(self.cache_dir, "cache_info.json")
if os.path.exists(cache_info_path):
with open(cache_info_path, "r") as f:
cache_info = json.load(f)
if (
cache_info.get("tag_name") == release_info["tag_name"]
and cache_info.get("asset_name") == asset["name"]
):
return True
return False
def _download_and_unzip(self, url, asset_name):
"""
Download and extract llama.cpp binaries.
Args:
url (str): URL of the asset to download.
asset_name (str): Name of the asset file.
"""
os.makedirs(self.cache_dir, exist_ok=True)
zip_path = os.path.join(self.cache_dir, asset_name)
self._log(f"Downloading binary from: {url}")
response = requests.get(url)
if response.status_code == 200:
with open(zip_path, "wb") as file:
file.write(response.content)
self._log(f"Successfully downloaded: {asset_name}")
else:
raise RuntimeError(f"Failed to download binary: {url}")
extract_dir = os.path.join(self.cache_dir, "llama_cpp")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_dir)
self._log(f"Extracted binaries to: {extract_dir}")
def _update_cache_info(self, release_info, asset):
"""
Update cache metadata with the downloaded release info.
Args:
release_info (dict): Metadata of the latest release.
asset (dict): Metadata of the downloaded asset.
"""
cache_info = {"tag_name": release_info["tag_name"], "asset_name": asset["name"]}
cache_info_path = os.path.join(self.cache_dir, "cache_info.json")
with open(cache_info_path, "w") as f:
json.dump(cache_info, f)
def _log(self, message):
"""
Print a log message if verbosity is enabled.
Args:
message (str): Log message to print.
"""
if self.verbose:
print(f"[LlamaCpp] {message}")
class LlamaCppServer:
def __init__(
self,
@@ -20,6 +245,17 @@ class LlamaCppServer:
verbose=False,
timeout_minutes=5,
):
"""
Initialize the LlamaCppServer.
Args:
llama_cpp_path (str): Path to the llama.cpp binaries.
gguf_path (str): Path to the GGUF model file.
cache_dir (str): Directory to store llama.cpp binaries and related files.
hugging_face (bool): Whether the model is hosted on Hugging Face.
verbose (bool): Enable verbose logging.
timeout_minutes (int): Timeout duration for shutting down idle servers.
"""
self.verbose = verbose
self.hugging_face = hugging_face
self.cache_dir = cache_dir
@@ -33,14 +269,18 @@ class LlamaCppServer:
self.timeout_minutes = timeout_minutes
self._auto_terminate_thread = None
# Validate llama_cpp_path
if llama_cpp_path is None:
raise ValueError("llama_cpp_path must be provided.")
elif not os.path.exists(llama_cpp_path):
raise FileNotFoundError(
f"Specified llama_cpp_path not found: {llama_cpp_path}"
)
# Validate gguf_path
if gguf_path and not os.path.exists(gguf_path) and not hugging_face:
raise FileNotFoundError(f"Specified gguf_path not found: {gguf_path}")
# Start the server if gguf_path is provided
if gguf_path:
self._start_server_in_thread()
@@ -57,6 +297,7 @@ class LlamaCppServer:
# Wait for the thread to start the server
while self._server_url is None:
time.sleep(1)
# Update the last-used timestamp whenever this property is accessed
self.last_used = time.time()
return self._server_url
@@ -70,13 +311,23 @@
self._server_url = None
self.port = None
self._log("Llama server successfully killed.")
if self._server_thread and self._server_thread.is_alive():
self._server_thread.join()
if self._auto_terminate_thread and self._auto_terminate_thread.is_alive():
self._auto_terminate_thread.join()
def chat_completion(self, payload):
"""
Send a chat completion request to the server.
Args:
payload (dict): Payload for the chat completion request.
Returns:
dict or generator: Response from the server (non-streaming or streaming mode).
"""
if self._server_url is None:
self._log(
"Server is off. Restarting the server before making the request..."
@@ -86,19 +337,39 @@
# Wait for the thread to start the server
while self._server_url is None:
time.sleep(1)
# Reset the last-used timestamp
self.last_used = time.time()
endpoint = f"{self._server_url}/v1/chat/completions"
self._log(f"Sending chat completion request to {endpoint}...")
# Check if streaming is enabled in the payload
if payload.get("stream", False):
self._log(f"Streaming mode enabled. Returning a generator.")
response = requests.post(endpoint, json=payload, stream=True)
if response.status_code == 200:
# Return a generator for streaming responses
def stream_response():
for line in response.iter_lines(decode_unicode=True):
yield line
return stream_response()
else:
self._log(
f"Request failed with status code: {response.status_code} - {response.text}"
)
response.raise_for_status()
else:
# Non-streaming mode
response = requests.post(endpoint, json=payload)
if response.status_code == 200:
self._log("Request successful.")
return response.json()
else:
self._log(
f"Request failed with status code: {response.status_code} - {response.text}"
)
response.raise_for_status()
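Consumer sketch (not part of the commit): one way a caller might drain the streaming generator returned above, reusing the llama instance from the earlier sketch and assuming llama-server emits OpenAI-style SSE lines of the form "data: {...}" terminated by "data: [DONE]".

import json

stream = llama.chat_completion(
    {
        "model": "qwen2.5-0.5b-instruct-q4_k_m.gguf",  # hypothetical model name
        "messages": [{"role": "user", "content": "Tell me a joke."}],
        "stream": True,
    }
)
for line in stream:
    # Each yielded item is a raw SSE line; payload chunks are prefixed with "data: ".
    if not line or not line.startswith("data: "):
        continue
    data = line[len("data: "):]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()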
def _start_server_in_thread(self):
"""Start the server in a separate thread."""
@@ -142,13 +413,16 @@
raise ValueError(
f"GGUF model path is not specified or invalid: {self.gguf_path}"
)
server_binary = os.path.join(
self.llama_cpp_path, "build", "bin", "llama-server"
)
if not os.path.exists(server_binary):
raise FileNotFoundError(f"Server binary not found: {server_binary}")
# Ensure the binary is executable
self._set_executable(server_binary)
# Find an available port
self.port = self._find_available_port(start_port=10000)
if self.port is None:
@@ -164,7 +438,6 @@
else:
commands.extend(["-m", self.gguf_path, "--port", str(self.port)])
self._log(f"{commands}")
self.server_process = subprocess.Popen(
commands,
stdout=subprocess.PIPE,
@@ -204,92 +477,3 @@
"""Print a log message if verbosity is enabled."""
if self.verbose:
print(f"[LlamaCppServer] {message}")
def _install_llama_cpp_binaries(self):
"""Download and install llama.cpp binaries."""
self._log("Installing llama.cpp binaries...")
release_info = self._get_latest_release()
assets = release_info["assets"]
asset = self._get_appropriate_asset(assets)
if not asset:
raise RuntimeError("No appropriate binary found for your system.")
asset_name = asset["name"]
if self._check_cache(release_info, asset):
self._log("Using cached llama.cpp binaries.")
else:
self._download_and_unzip(asset["browser_download_url"], asset_name)
self._update_cache_info(release_info, asset)
return os.path.join(self.cache_dir, "llama_cpp")
def _get_latest_release(self):
"""Fetch the latest release of llama.cpp from GitHub."""
api_url = "https://api.github.com/repos/ggerganov/llama.cpp/releases/latest"
response = requests.get(api_url)
if response.status_code == 200:
return response.json()
else:
raise RuntimeError(
f"Failed to fetch release info. Status code: {response.status_code}"
)
def _get_appropriate_asset(self, assets):
"""Select the appropriate binary asset for the current system."""
system = platform.system().lower()
machine = platform.machine().lower()
processor = platform.processor()
if system == "windows":
if "arm" in machine:
return next((a for a in assets if "win-arm64" in a["name"]), None)
elif "avx512" in processor:
return next((a for a in assets if "win-avx512-x64" in a["name"]), None)
elif "avx2" in processor:
return next((a for a in assets if "win-avx2-x64" in a["name"]), None)
elif "avx" in processor:
return next((a for a in assets if "win-avx-x64" in a["name"]), None)
else:
return next((a for a in assets if "win-noavx-x64" in a["name"]), None)
elif system == "darwin":
if "arm" in machine:
return next((a for a in assets if "macos-arm64" in a["name"]), None)
else:
return next((a for a in assets if "macos-x64" in a["name"]), None)
elif system == "linux":
return next((a for a in assets if "ubuntu-x64" in a["name"]), None)
return None
def _check_cache(self, release_info, asset):
"""Check whether the latest binaries are already cached."""
cache_info_path = os.path.join(self.cache_dir, "cache_info.json")
if os.path.exists(cache_info_path):
with open(cache_info_path, "r") as f:
cache_info = json.load(f)
if (
cache_info.get("tag_name") == release_info["tag_name"]
and cache_info.get("asset_name") == asset["name"]
):
return True
return False
def _download_and_unzip(self, url, asset_name):
"""Download and extract llama.cpp binaries."""
os.makedirs(self.cache_dir, exist_ok=True)
zip_path = os.path.join(self.cache_dir, asset_name)
self._log(f"Downloading binary from: {url}")
response = requests.get(url)
if response.status_code == 200:
with open(zip_path, "wb") as file:
file.write(response.content)
self._log(f"Successfully downloaded: {asset_name}")
else:
raise RuntimeError(f"Failed to download binary: {url}")
extract_dir = os.path.join(self.cache_dir, "llama_cpp")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_dir)
self._log(f"Extracted binaries to: {extract_dir}")
def _update_cache_info(self, release_info, asset):
"""Update cache metadata with the downloaded release info."""
cache_info = {"tag_name": release_info["tag_name"], "asset_name": asset["name"]}
cache_info_path = os.path.join(self.cache_dir, "cache_info.json")
with open(cache_info_path, "w") as f:
json.dump(cache_info, f)