enh: provider-agnostic

Justin Hayes 2024-07-26 09:31:27 -04:00 committed by GitHub
parent 95851a78fa
commit e49419ccd3


@@ -2,7 +2,7 @@
 title: RouteLLM Pipeline
 author: justinh-rahb
 date: 2024-07-25
-version: 0.1.0
+version: 0.2.0
 license: MIT
 description: A pipeline for routing LLM requests using RouteLLM framework, compatible with OpenAI API.
 requirements: routellm, pydantic, requests
@@ -10,7 +10,6 @@ requirements: routellm, pydantic, requests
 from typing import List, Union, Generator, Iterator
 from pydantic import BaseModel
-import os
 import logging
 from routellm.controller import Controller
@@ -31,38 +30,19 @@ class Pipeline:
         self.valves = self.Valves()
         self.controller = None
-        # Set the environment variables for API keys and base URLs
-        self._set_environment_variables()
         self._initialize_controller()
-    def _set_environment_variables(self):
-        os.environ["OPENAI_API_KEY"] = self.valves.ROUTELLM_STRONG_API_KEY
-        logging.info(f"Setting OPENAI_API_KEY to: {os.environ['OPENAI_API_KEY']}")
-        os.environ["WEAK_MODEL_API_KEY"] = self.valves.ROUTELLM_WEAK_API_KEY
-        logging.info(f"Setting WEAK_MODEL_API_KEY to: {os.environ['WEAK_MODEL_API_KEY']}")
-        if self.valves.ROUTELLM_STRONG_BASE_URL:
-            os.environ['OPENAI_BASE_URL'] = self.valves.ROUTELLM_STRONG_BASE_URL
-            logging.info(f"Setting OPENAI_BASE_URL to: {os.environ['OPENAI_BASE_URL']}")
-        if self.valves.ROUTELLM_WEAK_BASE_URL:
-            os.environ['WEAK_MODEL_BASE_URL'] = self.valves.ROUTELLM_WEAK_BASE_URL
-            logging.info(f"Setting WEAK_MODEL_BASE_URL to: {os.environ['WEAK_MODEL_BASE_URL']}")
     def pipelines(self) -> List[dict]:
         return [{"id": f"routellm.{self.valves.ROUTELLM_ROUTER}", "name": f"RouteLLM/{self.valves.ROUTELLM_ROUTER}"}]
     async def on_startup(self):
-        logging.info(f"on_startup:{__name__}")
+        logging.info(f"on_startup: {__name__}")
     async def on_shutdown(self):
-        logging.info(f"on_shutdown:{__name__}")
+        logging.info(f"on_shutdown: {__name__}")
     async def on_valves_updated(self):
-        logging.info(f"on_valves_updated:{__name__}")
-        self._set_environment_variables()
+        logging.info(f"on_valves_updated: {__name__}")
         self._initialize_controller()
     def _initialize_controller(self):
@@ -70,12 +50,10 @@ class Pipeline:
         strong_model = self.valves.ROUTELLM_STRONG_MODEL
         weak_model = self.valves.ROUTELLM_WEAK_MODEL
-        # Adjust model names if base URLs are provided
-        if self.valves.ROUTELLM_STRONG_BASE_URL:
-            strong_model = f"openai/{strong_model}"
-        if self.valves.ROUTELLM_WEAK_BASE_URL:
-            weak_model = f"openai/{weak_model}"
+        # Set the API keys as environment variables
+        import os
+        os.environ["OPENAI_API_KEY"] = self.valves.ROUTELLM_STRONG_API_KEY
         self.controller = Controller(
             routers=[self.valves.ROUTELLM_ROUTER],
             strong_model=strong_model,
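
Note (not part of the diff): a minimal consolidated sketch of how the controller setup reads after this hunk. The standalone function name `initialize_controller` and the `valves` parameter are illustrative stand-ins for the pipeline method and its Valves object; only fields that appear in the diff are used, and the Controller kwargs past the truncated context line are assumed.

```python
import os
from routellm.controller import Controller  # from the routellm package listed in requirements

def initialize_controller(valves):
    """Sketch of the provider-agnostic controller setup after this change."""
    strong_model = valves.ROUTELLM_STRONG_MODEL
    weak_model = valves.ROUTELLM_WEAK_MODEL

    # The strong-model key is now exported via the environment instead of
    # being tied to per-provider base-URL valves.
    os.environ["OPENAI_API_KEY"] = valves.ROUTELLM_STRONG_API_KEY

    return Controller(
        routers=[valves.ROUTELLM_ROUTER],
        strong_model=strong_model,
        weak_model=weak_model,  # assumed to follow the truncated context line above
    )
```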
@@ -93,12 +71,19 @@ class Pipeline:
             return "Error: RouteLLM controller not initialized. Please update valves with valid API keys and configuration."
         try:
-            response = self.controller.chat.completions.create(
-                model=f"router-{self.valves.ROUTELLM_ROUTER}-{self.valves.ROUTELLM_THRESHOLD}",
+            model_name = f"router-{self.valves.ROUTELLM_ROUTER}-{self.valves.ROUTELLM_THRESHOLD}"
+            # Prepare parameters, excluding 'model' and 'messages' if they're in body
+            params = {k: v for k, v in body.items() if k not in ['model', 'messages'] and v is not None}
+            # Ensure 'user' is a string if present
+            if 'user' in params and not isinstance(params['user'], str):
+                params['user'] = str(params['user'])
+            response = self.controller.completion(
+                model=model_name,
                 messages=messages,
-                max_tokens=body.get("max_tokens", 4096),
-                temperature=body.get("temperature", 0.8),
-                stream=body.get("stream", False),
+                **params
             )
             if body.get("stream", False):
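
Note (not part of the diff): a usage illustration of the parameter filtering added in this hunk. The request body values below are made up; only the filtering logic is taken from the diff.

```python
# Hypothetical Open WebUI request body (example values only)
body = {
    "model": "routellm.mf",
    "messages": [{"role": "user", "content": "Hi"}],
    "temperature": 0.7,
    "user": 42,          # non-string user id gets coerced to str
    "max_tokens": None,  # None values are dropped
    "stream": False,
}

# Everything except 'model' and 'messages' is forwarded as-is to completion()
params = {k: v for k, v in body.items() if k not in ["model", "messages"] and v is not None}
if "user" in params and not isinstance(params["user"], str):
    params["user"] = str(params["user"])

print(params)
# {'temperature': 0.7, 'user': '42', 'stream': False}
```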