add style option in prompt

This commit is contained in:
AkaCyberMac 2024-11-26 14:50:09 +02:00
parent 3526fd0368
commit 5e445df828

View File

@ -14,6 +14,8 @@ from typing import List, Union, Generator, Iterator
from pydantic import BaseModel from pydantic import BaseModel
from openai import OpenAI from openai import OpenAI
import os import os
import re
from difflib import get_close_matches
class Pipeline: class Pipeline:
class Valves(BaseModel): class Valves(BaseModel):
@ -23,6 +25,12 @@ class Pipeline:
self.name = "Recraft AI Pipeline" self.name = "Recraft AI Pipeline"
self.valves = self.Valves(RECRAFT_API_TOKEN=os.getenv("RECRAFT_API_TOKEN", "")) self.valves = self.Valves(RECRAFT_API_TOKEN=os.getenv("RECRAFT_API_TOKEN", ""))
self.client = None self.client = None
self.available_styles = [
"realistic_image",
"digital_illustration",
"vector_illustration",
"icon"
]
async def on_startup(self): async def on_startup(self):
print(f"on_startup:{__name__}") print(f"on_startup:{__name__}")
@ -34,16 +42,41 @@ class Pipeline:
async def on_shutdown(self): async def on_shutdown(self):
print(f"on_shutdown:{__name__}") print(f"on_shutdown:{__name__}")
def get_closest_style(self, input_style: str) -> str:
# Convert input and available styles to lowercase for better matching
input_style = input_style.lower()
style_map = {s.lower(): s for s in self.available_styles}
# Try to find close matches
matches = get_close_matches(input_style, style_map.keys(), n=1, cutoff=0.6)
if matches:
closest = matches[0]
print(f"Using style '{style_map[closest]}' for input '{input_style}'")
return style_map[closest]
return "realistic_image" # default fallback
def pipe( def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]: ) -> Union[str, Generator, Iterator]:
print(f"pipe:{__name__}") print(f"pipe:{__name__}")
try: try:
# Extract style from prompt if provided in [style] format
style_match = re.search(r'\[(.*?)\]', user_message)
selected_style = self.get_closest_style(style_match.group(1)) if style_match else "realistic_image"
# Clean the prompt by removing the style specification
clean_prompt = re.sub(r'\[.*?\]', '', user_message).strip()
# Select model based on style
model = 'recraft20b' if selected_style == 'icon' else 'recraftv3'
response = self.client.images.generate( response = self.client.images.generate(
prompt=user_message, prompt=clean_prompt,
style='realistic_image', style=selected_style,
size='1280x1024', size='1280x1024',
model=model,
) )
print(response) print(response)