From 2109df82305b6384cbc67837d9c4b1eacd5ea573 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Tue, 28 May 2024 16:58:56 -0700
Subject: [PATCH] refac: on_valves_update hook added

---
 main.py                                       |   5 +-
 .../examples/litellm_manifold_pipeline.py     |   5 +-
 .../litellm_subprocess_manifold_pipeline.py   | 198 ++++++++++++++++++
 3 files changed, 205 insertions(+), 3 deletions(-)
 create mode 100644 pipelines/examples/litellm_subprocess_manifold_pipeline.py

diff --git a/main.py b/main.py
index 3437a56..f8fa864 100644
--- a/main.py
+++ b/main.py
@@ -268,18 +268,19 @@ async def update_valves(pipeline_id: str, form_data: dict):
 
     pipeline_module = PIPELINE_MODULES[pipeline_id]
 
-    await pipeline_module.on_shutdown()
     try:
         ValvesModel = pipeline_module.valves.__class__
         valves = ValvesModel(**form_data)
         pipeline_module.valves = valves
+
+        if hasattr(pipeline_module, "on_valves_update"):
+            await pipeline_module.on_valves_update()
     except Exception as e:
         print(e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"{str(e)}",
         )
 
-    await pipeline_module.on_startup()
     return pipeline_module.valves
 
diff --git a/pipelines/examples/litellm_manifold_pipeline.py b/pipelines/examples/litellm_manifold_pipeline.py
index 8d6d668..dbe1fe0 100644
--- a/pipelines/examples/litellm_manifold_pipeline.py
+++ b/pipelines/examples/litellm_manifold_pipeline.py
@@ -31,7 +31,6 @@ class Pipeline:
     async def on_startup(self):
         # This function is called when the server is started or after valves are updated.
         print(f"on_startup:{__name__}")
-        self.pipelines = self.get_litellm_models()
         pass
 
     async def on_shutdown(self):
@@ -39,6 +38,10 @@ class Pipeline:
         print(f"on_shutdown:{__name__}")
         pass
 
+    async def on_valves_update(self):
+        self.pipelines = self.get_litellm_models()
+        pass
+
     def get_litellm_models(self):
         if self.valves.LITELLM_BASE_URL:
             try:
diff --git a/pipelines/examples/litellm_subprocess_manifold_pipeline.py b/pipelines/examples/litellm_subprocess_manifold_pipeline.py
new file mode 100644
index 0000000..88ff220
--- /dev/null
+++ b/pipelines/examples/litellm_subprocess_manifold_pipeline.py
@@ -0,0 +1,198 @@
+from typing import List, Union, Generator, Iterator
+from schemas import OpenAIChatMessage
+from pydantic import BaseModel
+import requests
+
+
+import os
+import asyncio
+import subprocess
+import yaml
+
+
+class Pipeline:
+    def __init__(self):
+        # You can also set the pipelines that are available in this pipeline.
+        # Set manifold to True if you want to use this pipeline as a manifold.
+        # Manifold pipelines can have multiple pipelines.
+        self.type = "manifold"
+
+        # Optionally, you can set the id and name of the pipeline.
+        # Assign a unique identifier to the pipeline.
+        # The identifier must be unique across all pipelines.
+        # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
+        self.id = "litellm_subprocess_manifold"
+
+        # Optionally, you can set the name of the manifold pipeline.
+        self.name = "LiteLLM: "
+
+        class Valves(BaseModel):
+            LITELLM_CONFIG_DIR: str = "./litellm/config.yaml"
+            LITELLM_PROXY_PORT: int = 4001
+            LITELLM_PROXY_HOST: str = "127.0.0.1"
+            litellm_config: dict = {}
+
+        # Initialize valves
+        self.valves = Valves(**{"LITELLM_CONFIG_DIR": "./litellm/config.yaml"})
+        self.pipelines = []
+
+        self.background_process = None
+        pass
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+
+        # Check if the config file exists; create a default one if it doesn't
+        if not os.path.exists(self.valves.LITELLM_CONFIG_DIR):
+            with open(self.valves.LITELLM_CONFIG_DIR, "w") as file:
+                yaml.dump(
+                    {
+                        "general_settings": {},
+                        "litellm_settings": {},
+                        "model_list": [],
+                        "router_settings": {},
+                    },
+                    file,
+                )
+
+            print(
+                f"Config file not found. Created a default config file at {self.valves.LITELLM_CONFIG_DIR}"
+            )
+
+        with open(self.valves.LITELLM_CONFIG_DIR, "r") as file:
+            litellm_config = yaml.safe_load(file)
+
+        self.valves.litellm_config = litellm_config
+
+        asyncio.create_task(self.start_litellm_background())
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+        print(f"on_shutdown:{__name__}")
+        await self.shutdown_litellm_background()
+        pass
+
+    async def on_valves_update(self):
+        print(f"on_valves_update:{__name__}")
+
+        with open(self.valves.LITELLM_CONFIG_DIR, "r") as file:
+            litellm_config = yaml.safe_load(file)
+
+        self.valves.litellm_config = litellm_config
+
+        await self.shutdown_litellm_background()
+        await self.start_litellm_background()
+        pass
+
+    async def run_background_process(self, command):
+        print("run_background_process")
+
+        try:
+            # Log the command to be executed
+            print(f"Executing command: {command}")
+
+            # Execute the command and create a subprocess
+            process = await asyncio.create_subprocess_exec(
+                *command,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+            self.background_process = process
+            print("Subprocess started successfully.")
+
+            self.pipelines = self.get_litellm_models()
+
+            # Stream stdout line by line; reading stderr to EOF first would block until the proxy exits
+            async for line in process.stdout:
+                print(line.decode().strip())
+
+            # Capture any remaining STDERR for debugging purposes
+            stderr_output = await process.stderr.read()
+            stderr_text = stderr_output.decode().strip()
+            if stderr_text:
+                print(f"Subprocess STDERR: {stderr_text}")
+
+            # Wait for the process to finish
+            returncode = await process.wait()
+            print(f"Subprocess exited with return code {returncode}")
+        except Exception as e:
+            print(f"Failed to start subprocess: {e}")
+            raise  # Re-raise the exception so the failure propagates
+
+    async def start_litellm_background(self):
+        print("start_litellm_background")
+        # Command to run in the background
+        command = [
+            "litellm",
+            "--port",
+            str(self.valves.LITELLM_PROXY_PORT),
+            "--host",
+            self.valves.LITELLM_PROXY_HOST,
+            "--telemetry",
+            "False",
+            "--config",
+            self.valves.LITELLM_CONFIG_DIR,
+        ]
+
+        await self.run_background_process(command)
+
+    async def shutdown_litellm_background(self):
+        print("shutdown_litellm_background")
+
+        if self.background_process:
+            self.background_process.terminate()
+            await self.background_process.wait()  # Ensure the process has terminated
+            print("Subprocess terminated")
+            self.background_process = None
+
+    def get_litellm_models(self):
+        if self.background_process:
+            try:
+                r = requests.get(
+                    f"http://{self.valves.LITELLM_PROXY_HOST}:{self.valves.LITELLM_PROXY_PORT}/v1/models"
+                )
+                models = r.json()
+                return [
+                    {
+                        "id": model["id"],
+                        "name": model.get("name", model["id"]),
+                    }
+                    for model in models["data"]
+                ]
+            except Exception as e:
+                print(f"Error: {e}")
+                return [
+                    {
+                        "id": self.id,
+                        "name": "Could not fetch models from LiteLLM; please update the URL in the valves.",
+                    },
+                ]
+        else:
+            return []
+
+    def pipe(
+        self, user_message: str, model_id: str, messages: List[dict], body: dict
+    ) -> Union[str, Generator, Iterator]:
+        if "user" in body:
+            print("######################################")
+            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
+            print(f"# Message: {user_message}")
+            print("######################################")
+
+        try:
+            r = requests.post(
+                url=f"http://{self.valves.LITELLM_PROXY_HOST}:{self.valves.LITELLM_PROXY_PORT}/v1/chat/completions",
+                json={**body, "model": model_id, "user_id": body.get("user", {}).get("id")},
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            if body.get("stream"):
+                return r.iter_lines()
+            else:
+                return r.json()
+        except Exception as e:
+            return f"Error: {e}"
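
Note: a minimal sketch (hypothetical, not part of this patch) of a pipeline
opting in to the new hook. After this change, update_valves() in main.py no
longer cycles on_shutdown()/on_startup(); it assigns the new valves and then
calls on_valves_update() only when the module defines one. The module name,
valve fields, and load_models() helper below are made up for illustration:

    # example_pipeline.py (hypothetical)
    from pydantic import BaseModel


    class Pipeline:
        class Valves(BaseModel):
            BASE_URL: str = "http://localhost:4000"  # example valve

        def __init__(self):
            self.id = "example_manifold"
            self.name = "Example: "
            self.valves = self.Valves()
            self.pipelines = []

        async def on_startup(self):
            # Still called once when the server starts.
            self.pipelines = self.load_models()

        async def on_shutdown(self):
            # Still called once when the server stops.
            pass

        async def on_valves_update(self):
            # Called by update_valves() after self.valves has been replaced,
            # instead of the old on_shutdown()/on_startup() cycle.
            self.pipelines = self.load_models()

        def load_models(self):
            # Placeholder: recompute whatever state depends on the valves.
            return [{"id": "example_model", "name": self.valves.BASE_URL}]

Because the server guards the call with hasattr(pipeline_module,
"on_valves_update"), existing pipelines that only implement on_startup() and
on_shutdown() continue to work unchanged.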