From 3321a1b922588dfc7d137646c79ed69a3b327aeb Mon Sep 17 00:00:00 2001 From: Yanyutin753 <132346501+Yanyutin753@users.noreply.github.com> Date: Sun, 28 Apr 2024 12:00:52 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20expend=20the=20image=20format=20typ?= =?UTF-8?q?e=20after=20the=20file=20is=20downloaded?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/images/main.py | 60 +++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index fd72b203c..8c5457e8a 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -24,6 +24,7 @@ from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel from pathlib import Path +import mimetypes import uuid import base64 import json @@ -315,38 +316,47 @@ class GenerateImageForm(BaseModel): def save_b64_image(b64_str): - image_id = str(uuid.uuid4()) - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") - try: - # Split the base64 string to get the actual image data - img_data = base64.b64decode(b64_str) + header, encoded = b64_str.split(",", 1) + mime_type = header.split(";")[0] - # Write the image data to a file + image_format = mimetypes.guess_extension(mime_type) + img_data = base64.b64decode(encoded) + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR / f"{image_id}{image_format}" with open(file_path, "wb") as f: f.write(img_data) - - return image_id + return image_id, image_format except Exception as e: - log.error(f"Error saving image: {e}") - return None + log.exception(f"Error saving image: {e}") + return None, None def save_url_image(url): image_id = str(uuid.uuid4()) - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") - try: r = requests.get(url) r.raise_for_status() + if r.headers["content-type"].split("/")[0] == "image": - with open(file_path, "wb") as image_file: - image_file.write(r.content) + mime_type = r.headers["content-type"] + image_format = mimetypes.guess_extension(mime_type) + + if not image_format: + raise ValueError("Could not determine image type from MIME type") + + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}") + with open(file_path, "wb") as image_file: + for chunk in r.iter_content(chunk_size=8192): + image_file.write(chunk) + return image_id, image_format + else: + log.error(f"Url does not point to an image.") + return None, None - return image_id except Exception as e: log.exception(f"Error saving image: {e}") - return None + return None, None @app.post("/generations") @@ -385,8 +395,10 @@ def generate_image( images = [] for image in res["data"]: - image_id = save_b64_image(image["b64_json"]) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_id, image_format = save_b64_image(image["b64_json"]) + images.append( + {"url": f"/cache/image/generations/{image_id}{image_format}"} + ) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: @@ -422,8 +434,10 @@ def generate_image( images = [] for image in res["data"]: - image_id = save_url_image(image["url"]) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_id, image_format = save_url_image(image["url"]) + images.append( + {"url": f"/cache/image/generations/{image_id}{image_format}"} + ) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: @@ -460,8 +474,10 @@ def generate_image( images = [] for image in res["images"]: - image_id = save_b64_image(image) - images.append({"url": f"/cache/image/generations/{image_id}.png"}) + image_id, image_format = save_b64_image(image) + images.append( + {"url": f"/cache/image/generations/{image_id}{image_format}"} + ) file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") with open(file_body_path, "w") as f: