This commit is contained in:
Timothy Jaeryang Baek 2025-01-28 00:32:05 -08:00
parent ff84790ada
commit 58f262ba8d

View File

@ -12,9 +12,16 @@ import socket
class LlamaCppServer:
def __init__(
self, llama_cpp_path=None, gguf_path=None, cache_dir="./cache", verbose=False
self,
llama_cpp_path=None,
gguf_path=None,
cache_dir="./cache",
hugging_face=False,
verbose=False,
timeout_minutes=5,
):
self.verbose = verbose
self.hugging_face = hugging_face
self.cache_dir = cache_dir
self.llama_cpp_path = llama_cpp_path
self.gguf_path = gguf_path
@ -22,6 +29,9 @@ class LlamaCppServer:
self._server_url = None
self._server_thread = None
self.port = None
self.last_used = time.time() # Tracks the last time the server was used
self.timeout_minutes = timeout_minutes
self._auto_terminate_thread = None
# Fetch or validate llama path
if llama_cpp_path is None:
@ -34,6 +44,7 @@ class LlamaCppServer:
# Start the server if gguf_path is provided
if gguf_path:
self._start_server_in_thread()
self._start_auto_terminate_thread()
@property
def url(self):
@ -42,6 +53,8 @@ class LlamaCppServer:
raise ValueError(
"Server is not running. Start the server with a valid GGUF path."
)
# Update the last-used timestamp whenever this property is accessed
self.last_used = time.time()
return self._server_url
def kill(self):
@ -55,6 +68,31 @@ class LlamaCppServer:
self._log("Llama server successfully killed.")
if self._server_thread and self._server_thread.is_alive():
self._server_thread.join()
if self._auto_terminate_thread and self._auto_terminate_thread.is_alive():
self._auto_terminate_thread.join()
def chat_completion(self, payload):
"""Send a chat completion request to the server."""
if self._server_url is None:
raise RuntimeError(
"Server is not running. Start the server before making requests."
)
# Reset the last-used timestamp
self.last_used = time.time()
endpoint = f"{self._server_url}/v1/chat/completions"
self._log(f"Sending chat completion request to {endpoint}...")
response = requests.post(endpoint, json=payload)
if response.status_code == 200:
self._log("Request successful.")
return response.json()
else:
self._log(
f"Request failed with status code: {response.status_code} - {response.text}"
)
response.raise_for_status()
def _start_server_in_thread(self):
"""Start the server in a separate thread."""
@ -68,9 +106,33 @@ class LlamaCppServer:
self._server_thread = threading.Thread(target=target, daemon=True)
self._server_thread.start()
def _start_auto_terminate_thread(self):
"""Start the auto-terminate thread that monitors idle time."""
def monitor_idle_time():
while True:
time.sleep(10)
if (
self.server_process and self.server_process.poll() is None
): # Server is running
elapsed_time = time.time() - self.last_used
if elapsed_time > self.timeout_minutes * 60:
self._log(
"Server has been idle for too long. Auto-terminating..."
)
self.kill()
break
self._auto_terminate_thread = threading.Thread(
target=monitor_idle_time, daemon=True
)
self._auto_terminate_thread.start()
def _start_server(self):
"""Start the llama-server."""
if not self.gguf_path or not os.path.exists(self.gguf_path):
if not self.gguf_path or (
not self.hugging_face and not os.path.exists(self.gguf_path)
):
raise ValueError(
f"GGUF model path is not specified or invalid: {self.gguf_path}"
)
@ -93,8 +155,15 @@ class LlamaCppServer:
self._log(f"Using GGUF path: {self.gguf_path}")
self._log(f"Using port: {self.port}")
commands = [server_binary]
if self.hugging_face:
commands.extend(["-hf", self.gguf_path, "--port", str(self.port)])
else:
commands.extend(["-m", self.gguf_path, "--port", str(self.port)])
self._log(f"{commands}")
self.server_process = subprocess.Popen(
[server_binary, "-m", self.gguf_path, "--port", str(self.port)],
commands,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
@ -104,7 +173,7 @@ class LlamaCppServer:
self._server_url = None
for line in iter(self.server_process.stdout.readline, ""):
self._log(line.strip())
if "Listening on" in line:
if "listening on" in line:
self._server_url = f"http://localhost:{self.port}"
self._log(f"Server is now accessible at {self._server_url}")
break