From 5e445df828ee4c652c8f5e1b3773a4c9d4b03c89 Mon Sep 17 00:00:00 2001 From: AkaCyberMac Date: Tue, 26 Nov 2024 14:50:09 +0200 Subject: [PATCH] add style option in prompt --- .../integrations/recraft_pipeline.py | 37 ++++++++++++++++++- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/examples/pipelines/integrations/recraft_pipeline.py b/examples/pipelines/integrations/recraft_pipeline.py index 46d35d7..77e38a4 100644 --- a/examples/pipelines/integrations/recraft_pipeline.py +++ b/examples/pipelines/integrations/recraft_pipeline.py @@ -14,6 +14,8 @@ from typing import List, Union, Generator, Iterator from pydantic import BaseModel from openai import OpenAI import os +import re +from difflib import get_close_matches class Pipeline: class Valves(BaseModel): @@ -23,6 +25,12 @@ class Pipeline: self.name = "Recraft AI Pipeline" self.valves = self.Valves(RECRAFT_API_TOKEN=os.getenv("RECRAFT_API_TOKEN", "")) self.client = None + self.available_styles = [ + "realistic_image", + "digital_illustration", + "vector_illustration", + "icon" + ] async def on_startup(self): print(f"on_startup:{__name__}") @@ -34,16 +42,41 @@ class Pipeline: async def on_shutdown(self): print(f"on_shutdown:{__name__}") + def get_closest_style(self, input_style: str) -> str: + # Convert input and available styles to lowercase for better matching + input_style = input_style.lower() + style_map = {s.lower(): s for s in self.available_styles} + + # Try to find close matches + matches = get_close_matches(input_style, style_map.keys(), n=1, cutoff=0.6) + + if matches: + closest = matches[0] + print(f"Using style '{style_map[closest]}' for input '{input_style}'") + return style_map[closest] + return "realistic_image" # default fallback + def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: print(f"pipe:{__name__}") try: + # Extract style from prompt if provided in [style] format + style_match = re.search(r'\[(.*?)\]', user_message) + selected_style = self.get_closest_style(style_match.group(1)) if style_match else "realistic_image" + + # Clean the prompt by removing the style specification + clean_prompt = re.sub(r'\[.*?\]', '', user_message).strip() + + # Select model based on style + model = 'recraft20b' if selected_style == 'icon' else 'recraftv3' + response = self.client.images.generate( - prompt=user_message, - style='realistic_image', + prompt=clean_prompt, + style=selected_style, size='1280x1024', + model=model, ) print(response)