diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index fd72b203c..88cecc940 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -24,6 +24,7 @@ from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel from pathlib import Path +import mimetypes import uuid import base64 import json @@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel): def save_b64_image(b64_str): - image_id = str(uuid.uuid4()) - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") - try: - # Split the base64 string to get the actual image data - img_data = base64.b64decode(b64_str) + header, encoded = b64_str.split(",", 1) + mime_type = header.split(";")[0] - # Write the image data to a file + img_data = base64.b64decode(encoded) + + image_id = str(uuid.uuid4()) + image_format = mimetypes.guess_extension(mime_type) + + image_filename = f"{image_id}{image_format}" + file_path = IMAGE_CACHE_DIR / f"{image_filename}" with open(file_path, "wb") as f: f.write(img_data) - - return image_id + return image_filename except Exception as e: - log.error(f"Error saving image: {e}") + log.exception(f"Error saving image: {e}") return None def save_url_image(url): image_id = str(uuid.uuid4()) - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") - try: r = requests.get(url) r.raise_for_status() + if r.headers["content-type"].split("/")[0] == "image": - with open(file_path, "wb") as image_file: - image_file.write(r.content) + mime_type = r.headers["content-type"] + image_format = mimetypes.guess_extension(mime_type) + + if not image_format: + raise ValueError("Could not determine image type from MIME type") + + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}") + with open(file_path, "wb") as image_file: + for chunk in r.iter_content(chunk_size=8192): + image_file.write(chunk) + return image_id, image_format + else: + log.error(f"Url does not point to an image.") + return None, None - return image_id except Exception as e: log.exception(f"Error saving image: {e}") - return None + return None, None @app.post("/generations") @@ -385,8 +398,8 @@ def generate_image( images = [] for image in res["data"]: - image_id = save_b64_image(image["b64_json"]) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_filename = save_b64_image(image["b64_json"]) + images.append({"url": f"/cache/image/generations/{image_filename}"}) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: @@ -422,8 +435,10 @@ def generate_image( images = [] for image in res["data"]: - image_id = save_url_image(image["url"]) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_id, image_format = save_url_image(image["url"]) + images.append( + {"url": f"/cache/image/generations/{image_id}{image_format}"} + ) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: @@ -460,8 +475,8 @@ def generate_image( images = [] for image in res["images"]: - image_id = save_b64_image(image) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_filename = save_b64_image(image) + images.append({"url": f"/cache/image/generations/{image_filename}"}) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: