diff --git a/examples/pipelines/integrations/recraft_pipeline.py b/examples/pipelines/integrations/recraft_pipeline.py index 77e38a4..af02daf 100644 --- a/examples/pipelines/integrations/recraft_pipeline.py +++ b/examples/pipelines/integrations/recraft_pipeline.py @@ -31,6 +31,95 @@ class Pipeline: "vector_illustration", "icon" ] + self.style_substyles = { + "realistic_image": [ + "b_and_w", + "enterprise", + "evening_light", + "faded_nostalgia", + "forest_life", + "hard_flash", + "hdr", + "motion_blur", + "mystic_naturalism", + "natural_light", + "natural_tones", + "organic_calm", + "real_life_glow", + "retro_realism", + "retro_snapshot", + "studio_portrait", + "urban_drama", + "village_realism", + "warm_folk" + ], + "digital_illustration": [ + "2d_art_poster", + "2d_art_poster_2", + "engraving_color", + "grain", + "hand_drawn", + "hand_drawn_outline", + "handmade_3d", + "infantile_sketch", + "pixel_art", + "antiquarian", + "bold_fantasy", + "child_book", + "child_books", + "cover", + "crosshatch", + "digital_engraving", + "expressionism", + "freehand_details", + "grain_20", + "graphic_intensity", + "hard_comics", + "long_shadow", + "modern_folk", + "multicolor", + "neon_calm", + "noir", + "nostalgic_pastel", + "outline_details", + "pastel_gradient", + "pastel_sketch", + "pop_art", + "pop_renaissance", + "street_art", + "tablet_sketch", + "urban_glow", + "urban_sketching", + "vanilla_dreams", + "young_adult_book", + "young_adult_book_2" + ], + "vector_illustration": [ + "bold_stroke", + "chemistry", + "colored_stencil", + "contour_pop_art", + "cosmics", + "cutout", + "depressive", + "editorial", + "emotional_flat", + "engraving", + "infographical", + "line_art", + "line_circuit", + "linocut", + "marker_outline", + "mosaic", + "naivector", + "roundish_flat", + "segmented_colors", + "sharp_contrast", + "thin", + "vector_photo", + "vivid_shapes" + ] + } async def on_startup(self): print(f"on_startup:{__name__}") @@ -42,19 +131,45 @@ class Pipeline: async def on_shutdown(self): print(f"on_shutdown:{__name__}") - def get_closest_style(self, input_style: str) -> str: - # Convert input and available styles to lowercase for better matching - input_style = input_style.lower() + def get_style_and_substyle(self, input_text: str) -> tuple[str, str | None]: + """ + Find the best matching style and substyle from a single input. + Returns a tuple of (style, substyle) where substyle may be None. + """ + input_text = input_text.lower() + + # First try to match the main style style_map = {s.lower(): s for s in self.available_styles} + style_matches = get_close_matches(input_text, style_map.keys(), n=1, cutoff=0.6) - # Try to find close matches - matches = get_close_matches(input_style, style_map.keys(), n=1, cutoff=0.6) + # Create a map of all substyles to their parent styles + substyle_to_style = {} + for style, substyles in self.style_substyles.items(): + for substyle in substyles: + substyle_to_style[substyle.lower()] = (style, substyle) - if matches: - closest = matches[0] - print(f"Using style '{style_map[closest]}' for input '{input_style}'") - return style_map[closest] - return "realistic_image" # default fallback + # Try to match substyle + substyle_matches = get_close_matches(input_text, substyle_to_style.keys(), n=1, cutoff=0.6) + + # If we found a style match + if style_matches: + matched_style = style_map[style_matches[0]] + # If this style has substyles, try to find a default or matching substyle + if matched_style in self.style_substyles: + # If we also found a substyle match and it belongs to this style, use it + if substyle_matches and substyle_to_style[substyle_matches[0]][0] == matched_style: + return matched_style, substyle_to_style[substyle_matches[0]][1] + # Otherwise return None for substyle + return matched_style, None + return matched_style, None + + # If we found a substyle match, use its parent style + if substyle_matches: + style, substyle = substyle_to_style[substyle_matches[0]] + return style, substyle + + # Default to realistic_image with no substyle + return "realistic_image", None def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict @@ -62,22 +177,35 @@ class Pipeline: print(f"pipe:{__name__}") try: - # Extract style from prompt if provided in [style] format + # Extract style/substyle specification from prompt if provided in [text] format style_match = re.search(r'\[(.*?)\]', user_message) - selected_style = self.get_closest_style(style_match.group(1)) if style_match else "realistic_image" - # Clean the prompt by removing the style specification + # Get style and substyle from the input text + if style_match: + selected_style, selected_substyle = self.get_style_and_substyle(style_match.group(1)) + print(f"Matched style: {selected_style}, substyle: {selected_substyle}") + else: + selected_style, selected_substyle = "realistic_image", None + + # Clean the prompt by removing all bracketed content clean_prompt = re.sub(r'\[.*?\]', '', user_message).strip() # Select model based on style model = 'recraft20b' if selected_style == 'icon' else 'recraftv3' - response = self.client.images.generate( - prompt=clean_prompt, - style=selected_style, - size='1280x1024', - model=model, - ) + # Prepare request parameters + params = { + 'prompt': clean_prompt, + 'style': selected_style, + 'size': '1280x1024', + 'model': model, + } + + # Add substyle if specified and valid + if selected_substyle: + params['extra_body'] = {'substyle': selected_substyle} + + response = self.client.images.generate(**params) print(response) if response and response.data and len(response.data) > 0: