diff --git a/src/llama_cpp_runner/main.py b/src/llama_cpp_runner/main.py index 8e03a7d..6ab7b8b 100644 --- a/src/llama_cpp_runner/main.py +++ b/src/llama_cpp_runner/main.py @@ -12,9 +12,16 @@ import socket class LlamaCppServer: def __init__( - self, llama_cpp_path=None, gguf_path=None, cache_dir="./cache", verbose=False + self, + llama_cpp_path=None, + gguf_path=None, + cache_dir="./cache", + hugging_face=False, + verbose=False, + timeout_minutes=5, ): self.verbose = verbose + self.hugging_face = hugging_face self.cache_dir = cache_dir self.llama_cpp_path = llama_cpp_path self.gguf_path = gguf_path @@ -22,6 +29,9 @@ class LlamaCppServer: self._server_url = None self._server_thread = None self.port = None + self.last_used = time.time() # Tracks the last time the server was used + self.timeout_minutes = timeout_minutes + self._auto_terminate_thread = None # Fetch or validate llama path if llama_cpp_path is None: @@ -34,6 +44,7 @@ class LlamaCppServer: # Start the server if gguf_path is provided if gguf_path: self._start_server_in_thread() + self._start_auto_terminate_thread() @property def url(self): @@ -42,6 +53,8 @@ class LlamaCppServer: raise ValueError( "Server is not running. Start the server with a valid GGUF path." ) + # Update the last-used timestamp whenever this property is accessed + self.last_used = time.time() return self._server_url def kill(self): @@ -55,6 +68,31 @@ class LlamaCppServer: self._log("Llama server successfully killed.") if self._server_thread and self._server_thread.is_alive(): self._server_thread.join() + if self._auto_terminate_thread and self._auto_terminate_thread.is_alive(): + self._auto_terminate_thread.join() + + def chat_completion(self, payload): + """Send a chat completion request to the server.""" + if self._server_url is None: + raise RuntimeError( + "Server is not running. Start the server before making requests." + ) + + # Reset the last-used timestamp + self.last_used = time.time() + + endpoint = f"{self._server_url}/v1/chat/completions" + self._log(f"Sending chat completion request to {endpoint}...") + response = requests.post(endpoint, json=payload) + + if response.status_code == 200: + self._log("Request successful.") + return response.json() + else: + self._log( + f"Request failed with status code: {response.status_code} - {response.text}" + ) + response.raise_for_status() def _start_server_in_thread(self): """Start the server in a separate thread.""" @@ -68,9 +106,33 @@ class LlamaCppServer: self._server_thread = threading.Thread(target=target, daemon=True) self._server_thread.start() + def _start_auto_terminate_thread(self): + """Start the auto-terminate thread that monitors idle time.""" + + def monitor_idle_time(): + while True: + time.sleep(10) + if ( + self.server_process and self.server_process.poll() is None + ): # Server is running + elapsed_time = time.time() - self.last_used + if elapsed_time > self.timeout_minutes * 60: + self._log( + "Server has been idle for too long. Auto-terminating..." + ) + self.kill() + break + + self._auto_terminate_thread = threading.Thread( + target=monitor_idle_time, daemon=True + ) + self._auto_terminate_thread.start() + def _start_server(self): """Start the llama-server.""" - if not self.gguf_path or not os.path.exists(self.gguf_path): + if not self.gguf_path or ( + not self.hugging_face and not os.path.exists(self.gguf_path) + ): raise ValueError( f"GGUF model path is not specified or invalid: {self.gguf_path}" ) @@ -93,8 +155,15 @@ class LlamaCppServer: self._log(f"Using GGUF path: {self.gguf_path}") self._log(f"Using port: {self.port}") + commands = [server_binary] + if self.hugging_face: + commands.extend(["-hf", self.gguf_path, "--port", str(self.port)]) + else: + commands.extend(["-m", self.gguf_path, "--port", str(self.port)]) + + self._log(f"{commands}") self.server_process = subprocess.Popen( - [server_binary, "-m", self.gguf_path, "--port", str(self.port)], + commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, @@ -104,7 +173,7 @@ class LlamaCppServer: self._server_url = None for line in iter(self.server_process.stdout.readline, ""): self._log(line.strip()) - if "Listening on" in line: + if "listening on" in line: self._server_url = f"http://localhost:{self.port}" self._log(f"Server is now accessible at {self._server_url}") break