mirror of
https://github.com/open-webui/pipelines
synced 2025-05-11 08:01:08 +00:00
Add ImageGen exaple (Using DALL-E)
This commit is contained in:
parent
b4ed391fa5
commit
dbb68000f7
@ -0,0 +1,86 @@
|
||||
"""A manifold to integrate OpenAI's ImageGen models into Open-WebUI"""
|
||||
|
||||
from typing import List, Union, Generator, Iterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class Pipeline:
|
||||
"""OpenAI ImageGen pipeline"""
|
||||
|
||||
class Valves(BaseModel):
|
||||
"""Options to change from the WebUI"""
|
||||
|
||||
OPENAI_API_BASE_URL: str = "https://api.openai.com/v1"
|
||||
OPENAI_API_KEY: str = ""
|
||||
IMAGE_SIZE: str = "1024x1024"
|
||||
NUM_IMAGES: int = 1
|
||||
|
||||
def __init__(self):
|
||||
self.type = "manifold"
|
||||
self.name = "ImageGen: "
|
||||
|
||||
self.valves = self.Valves()
|
||||
self.client = OpenAI(
|
||||
base_url=self.valves.OPENAI_API_BASE_URL,
|
||||
api_key=self.valves.OPENAI_API_KEY,
|
||||
)
|
||||
|
||||
self.pipelines = self.get_openai_assistants()
|
||||
|
||||
async def on_startup(self) -> None:
|
||||
"""This function is called when the server is started."""
|
||||
print(f"on_startup:{__name__}")
|
||||
|
||||
async def on_shutdown(self):
|
||||
"""This function is called when the server is stopped."""
|
||||
print(f"on_shutdown:{__name__}")
|
||||
|
||||
async def on_valves_updated(self):
|
||||
"""This function is called when the valves are updated."""
|
||||
print(f"on_valves_updated:{__name__}")
|
||||
self.client = OpenAI(
|
||||
base_url=self.valves.OPENAI_API_BASE_URL,
|
||||
api_key=self.valves.OPENAI_API_KEY,
|
||||
)
|
||||
self.pipelines = self.get_openai_assistants()
|
||||
|
||||
def get_openai_assistants(self) -> List[dict]:
|
||||
"""Get the available ImageGen models from OpenAI
|
||||
|
||||
Returns:
|
||||
List[dict]: The list of ImageGen models
|
||||
"""
|
||||
|
||||
if self.valves.OPENAI_API_KEY:
|
||||
models = self.client.models.list()
|
||||
return [
|
||||
{
|
||||
"id": model.id,
|
||||
"name": model.id,
|
||||
}
|
||||
for model in models
|
||||
if "dall-e" in model.id
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
def pipe(
|
||||
self, user_message: str, model_id: str, messages: List[dict], body: dict
|
||||
) -> Union[str, Generator, Iterator]:
|
||||
print(f"pipe:{__name__}")
|
||||
|
||||
response = self.client.images.generate(
|
||||
model=model_id,
|
||||
prompt=user_message,
|
||||
size=self.valves.IMAGE_SIZE,
|
||||
n=self.valves.NUM_IMAGES,
|
||||
)
|
||||
|
||||
message = ""
|
||||
for image in response.data:
|
||||
if image.url:
|
||||
message += "\n"
|
||||
|
||||
yield message
|
Loading…
Reference in New Issue
Block a user