diff --git a/examples/pipelines/providers/routellm_pipeline.py b/examples/pipelines/providers/routellm_pipeline.py index f7203c9..bfc0cea 100644 --- a/examples/pipelines/providers/routellm_pipeline.py +++ b/examples/pipelines/providers/routellm_pipeline.py @@ -2,12 +2,13 @@ title: RouteLLM Pipeline author: justinh-rahb date: 2024-07-25 -version: 0.2.2 +version: 0.2.3 license: MIT description: A pipeline for routing LLM requests using RouteLLM framework, compatible with OpenAI API. requirements: routellm, pydantic, requests """ +import os from typing import List, Union, Generator, Iterator from pydantic import BaseModel, Field import logging @@ -15,8 +16,9 @@ from routellm.controller import Controller class Pipeline: class Valves(BaseModel): - ROUTELLM_ROUTER: str = Field( - default="mf", description="Identifier for the RouteLLM router." + ROUTELLM_SUFFIX: str = Field( + default="OpenAI", + description="Suffix to use for model identifier and name." ) ROUTELLM_STRONG_MODEL: str = Field( default="gpt-4o", description="Identifier for the strong model." @@ -24,51 +26,33 @@ class Pipeline: ROUTELLM_WEAK_MODEL: str = Field( default="gpt-4o-mini", description="Identifier for the weak model." ) - ROUTELLM_STRONG_API_KEY: str = Field( - default="sk-your-api-key", - description="API key for accessing the strong model." - ) - ROUTELLM_WEAK_API_KEY: str = Field( - default="sk-your-api-key", - description="API key for accessing the weak model." - ) - ROUTELLM_STRONG_BASE_URL: str = Field( + ROUTELLM_BASE_URL: str = Field( default="https://api.openai.com/v1", - description="Base URL for the strong model's API." + description="Base URL for the API." ) - ROUTELLM_WEAK_BASE_URL: str = Field( - default="https://api.openai.com/v1", - description="Base URL for the weak model's API." + ROUTELLM_API_KEY: str = Field( + default="sk-your-api-key", + description="API key for accessing models." + ) + ROUTELLM_ROUTER: str = Field( + default="mf", description="Identifier for the RouteLLM routing model." ) ROUTELLM_THRESHOLD: float = Field( default=0.11593, description="Threshold value for determining when to use the strong model." ) - ROUTELLM_SUFFIX: str = Field( - default="OpenAI", - description="Suffix to use for model identifier and name." - ) def __init__(self): self.type = "manifold" - self.valves = self.Valves() self.id = "routellm" - self.name = f"RouteLLM/" + self.name = "RouteLLM/" + + # Initialize valves with environment variables if available + self.valves = self.Valves( + ROUTELLM_API_KEY=os.getenv("OPENAI_API_KEY", "") + ) + self.controller = None - - self._initialize_controller() - - def pipelines(self) -> List[dict]: - return [{"id": f"{self.valves.ROUTELLM_SUFFIX.lower()}", "name": f"{self.valves.ROUTELLM_SUFFIX}"}] - - async def on_startup(self): - logging.info(f"on_startup: {__name__}") - - async def on_shutdown(self): - logging.info(f"on_shutdown: {__name__}") - - async def on_valves_updated(self): - logging.info(f"on_valves_updated: {__name__}") self._initialize_controller() def _initialize_controller(self): @@ -76,9 +60,8 @@ class Pipeline: strong_model = self.valves.ROUTELLM_STRONG_MODEL weak_model = self.valves.ROUTELLM_WEAK_MODEL - # Set the API keys as environment variables - import os - os.environ["OPENAI_API_KEY"] = self.valves.ROUTELLM_STRONG_API_KEY + # Set the API key as an environment variable + os.environ["OPENAI_API_KEY"] = self.valves.ROUTELLM_API_KEY self.controller = Controller( routers=[self.valves.ROUTELLM_ROUTER], @@ -90,11 +73,25 @@ class Pipeline: logging.error(f"Error initializing RouteLLM controller: {e}") self.controller = None + def pipelines(self) -> List[dict]: + return [{"id": f"{self.valves.ROUTELLM_SUFFIX.lower()}", "name": f"{self.valves.ROUTELLM_SUFFIX}"}] + + async def on_startup(self): + logging.info(f"on_startup: {__name__}") + self._initialize_controller() + + async def on_shutdown(self): + logging.info(f"on_shutdown: {__name__}") + + async def on_valves_updated(self): + logging.info(f"on_valves_updated: {__name__}") + self._initialize_controller() + def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: if not self.controller: - return "Error: RouteLLM controller not initialized. Please update valves with valid API keys and configuration." + return "Error: RouteLLM controller not initialized. Please update valves with valid API key and configuration." try: model_name = f"router-{self.valves.ROUTELLM_ROUTER}-{self.valves.ROUTELLM_THRESHOLD}"