enh: litellm manifold

Co-Authored-By: Artur Zdolinski <15941777+azdolinski@users.noreply.github.com>
This commit is contained in:
Timothy J. Baek 2024-06-03 13:24:57 -07:00
parent 8b5e0a05e9
commit 17bc338c4a

View File

@ -11,12 +11,15 @@ from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage from schemas import OpenAIChatMessage
from pydantic import BaseModel from pydantic import BaseModel
import requests import requests
import os
class Pipeline: class Pipeline:
class Valves(BaseModel): class Valves(BaseModel):
LITELLM_BASE_URL: str LITELLM_BASE_URL: str = ""
LITELLM_API_KEY: str = ""
LITELLM_PIPELINE_DEBUG: bool = False
def __init__(self): def __init__(self):
# You can also set the pipelines that are available in this pipeline. # You can also set the pipelines that are available in this pipeline.
@ -34,7 +37,15 @@ class Pipeline:
self.name = "LiteLLM: " self.name = "LiteLLM: "
# Initialize rate limits # Initialize rate limits
self.valves = self.Valves(**{"LITELLM_BASE_URL": "http://localhost:4001"}) self.valves = self.Valves(
**{
"LITELLM_BASE_URL": os.getenv(
"LITELLM_BASE_URL", "http://localhost:4001"
),
"LITELLM_API_KEY": os.getenv("LITELLM_API_KEY", "your-api-key-here"),
"LITELLM_PIPELINE_DEBUG": os.getenv("LITELLM_PIPELINE_DEBUG", False),
}
)
self.pipelines = [] self.pipelines = []
pass pass
@ -55,9 +66,16 @@ class Pipeline:
pass pass
def get_litellm_models(self): def get_litellm_models(self):
headers = {}
if self.valves.LITELLM_API_KEY:
headers["Authorization"] = f"Bearer {self.valves.LITELLM_API_KEY}"
if self.valves.LITELLM_BASE_URL: if self.valves.LITELLM_BASE_URL:
try: try:
r = requests.get(f"{self.valves.LITELLM_BASE_URL}/v1/models") r = requests.get(
f"{self.valves.LITELLM_BASE_URL}/v1/models", headers=headers
)
models = r.json() models = r.json()
return [ return [
{ {
@ -86,10 +104,20 @@ class Pipeline:
print(f"# Message: {user_message}") print(f"# Message: {user_message}")
print("######################################") print("######################################")
headers = {}
if self.valves.LITELLM_API_KEY:
headers["Authorization"] = f"Bearer {self.valves.LITELLM_API_KEY}"
try: try:
payload = {**body, "model": model_id, "user_id": body["user"]["id"]}
payload.pop("chat_id", None)
payload.pop("user", None)
payload.pop("title", None)
r = requests.post( r = requests.post(
url=f"{self.valves.LITELLM_BASE_URL}/v1/chat/completions", url=f"{self.valves.LITELLM_BASE_URL}/v1/chat/completions",
json={**body, "model": model_id, "user_id": body["user"]["id"]}, json=payload,
headers=headers,
stream=True, stream=True,
) )