fix: openai pipeline

This commit is contained in:
Timothy J. Baek 2024-06-02 16:58:28 -07:00
parent e024afc81c
commit c2f5200906

View File

@ -1,9 +1,15 @@
from typing import List, Union, Generator, Iterator from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage from schemas import OpenAIChatMessage
from pydantic import BaseModel
import os
import requests import requests
class Pipeline: class Pipeline:
class Valves(BaseModel):
OPENAI_API_KEY: str = ""
pass
def __init__(self): def __init__(self):
# Optionally, you can set the id and name of the pipeline. # Optionally, you can set the id and name of the pipeline.
# Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline. # Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline.
@ -11,6 +17,13 @@ class Pipeline:
# The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
# self.id = "openai_pipeline" # self.id = "openai_pipeline"
self.name = "OpenAI Pipeline" self.name = "OpenAI Pipeline"
self.valves = self.Valves(
**{
"OPENAI_API_KEY": os.getenv(
"OPENAI_API_KEY", "your-openai-api-key-here"
)
}
)
pass pass
async def on_startup(self): async def on_startup(self):
@ -39,10 +52,21 @@ class Pipeline:
headers["Authorization"] = f"Bearer {OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
payload = {**body, "model": MODEL}
if "user" in payload:
del payload["user"]
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
print(payload)
try: try:
r = requests.post( r = requests.post(
url="https://api.openai.com/v1/chat/completions", url="https://api.openai.com/v1/chat/completions",
json={**body, "model": MODEL}, json=payload,
headers=headers, headers=headers,
stream=True, stream=True,
) )