# Mirror of https://github.com/deepseek-ai/Janus
# (synced 2024-12-28 14:52:12 +00:00; 209 lines, 6.6 KiB, Python)
|
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||
|
|
||
|
from typing import List, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torchvision
|
||
|
import torchvision.transforms.functional
|
||
|
from PIL import Image
|
||
|
from transformers import AutoImageProcessor, PretrainedConfig
|
||
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
||
|
from transformers.image_utils import to_numpy_array
|
||
|
from transformers.utils import logging
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# Union of the image representations: NumPy array, torch tensor or PIL image.
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]

# NOTE: despite the IMAGENET_ prefix, these are the OpenAI CLIP normalization
# statistics (the same values transformers exposes as OPENAI_CLIP_MEAN /
# OPENAI_CLIP_STD); the classic ImageNet statistics are
# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225). Names kept for compatibility.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

# Inception-style statistics: 0.5 mean and std per channel, which maps
# values in [0, 1] to [-1, 1].
IMAGENET_INCEPTION_MEAN = (0.5,) * 3
IMAGENET_INCEPTION_STD = (0.5,) * 3
|
||
|
|
||
|
|
||
|
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas, centring it along its shorter axis.

    Args:
        pil_img (PIL.Image.Image): image to pad.
        background_color: fill colour for the padded area, in whatever form
            ``PIL.Image.new`` accepts for ``pil_img.mode``.

    Returns:
        PIL.Image.Image: ``pil_img`` itself when it is already square;
        otherwise a new image of side ``max(width, height)`` with ``pil_img``
        pasted so that the leftover space on the shorter axis is split evenly
        (offset rounded down).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of the two offsets is non-zero: the one for the shorter axis.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(pil_img, offset)
    return canvas
|
||
|
|
||
|
|
||
|
class VLMImageProcessorConfig(PretrainedConfig):
    """Configuration object carrying the preprocessing settings of the VLM image processor.

    The fields mirror the constructor arguments of ``VLMImageProcessor``:
    target ``image_size``, per-side lower bound ``min_size``, normalization
    statistics ``image_mean`` / ``image_std``, the pixel ``rescale_factor``
    and the ``do_normalize`` switch. Remaining keyword arguments are handed
    to ``PretrainedConfig``.

    NOTE(review): ``image_size`` has no default, so ``VLMImageProcessorConfig()``
    cannot be built without arguments; some transformers versions instantiate
    config classes with no arguments (e.g. when computing a diff dict for
    serialization / repr) — confirm this path is not hit.
    """

    model_type = "deepseek_vlm"

    # Declared attribute types; the values are assigned in __init__.
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073),
        image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        # Processor-specific fields are stored before PretrainedConfig
        # processes the generic keyword arguments.
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)
|
||
|
|
||
|
|
||
|
class VLMImageProcessor(BaseImageProcessor):
    """Resize-and-pad image processor producing ``pixel_values`` for the VLM.

    Each image is resized (bicubic, antialiased) so its longer side becomes
    ``image_size`` while keeping the aspect ratio (each side is clamped to at
    least ``min_size``). It is then padded to a square filled with
    ``background_color``, rescaled by ``rescale_factor`` and, when
    ``do_normalize`` is set, normalized with ``image_mean`` / ``image_std``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        """Store the preprocessing settings.

        Args:
            image_size: target length of the longer image side.
            min_size: lower bound applied to each side after resizing.
            image_mean: per-channel mean for normalization; also sets the
                padding colour. ``None`` falls back to (127, 127, 127) padding.
            image_std: per-channel standard deviation for normalization.
            rescale_factor: multiplier applied to raw pixel values
                (the default maps [0, 255] to [0, 1]).
            do_normalize: whether ``preprocess`` applies mean/std normalization.
            **kwargs: forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        # Pad with the mean colour in 0-255 units; with the default
        # rescale_factor, padded pixels then normalize to approximately 0.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image.Image) -> np.ndarray:
        """Resize ``pil_img`` preserving aspect ratio, then pad it to a square.

        The longer side is scaled to ``self.image_size`` and each side is
        clamped to at least ``self.min_size``; ``expand2square`` then pads the
        result with ``self.background_color``.

        Args:
            pil_img (PIL.Image.Image): [H, W, 3] image, expected to be RGB.

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size] channels-first
                array for an RGB input.

        Raises:
            ValueError: if the input has a non-positive side, or the computed
                target size is non-positive.
        """
        width, height = pil_img.size

        # Validate first: max(width, height) is used as a divisor below, so a
        # zero-sized image must be rejected before the division happens.
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid size! orig size = {pil_img.size}")

        max_size = max(width, height)
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if size[0] <= 0 or size[1] <= 0:
            raise ValueError(
                f"Invalid size! orig size = {pil_img.size}, new size = {size}"
            )

        pil_img = torchvision.transforms.functional.resize(
            pil_img,
            size,
            interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
            antialias=True,
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Resize, pad, rescale and (optionally) normalize a batch of images.

        Args:
            images: a single PIL image or an iterable of PIL images.
            return_tensors: tensor type passed to ``BatchFeature``.
            **kwargs: accepted for API compatibility; not used.

        Returns:
            BatchFeature: ``{"pixel_values": ...}`` converted to ``return_tensors``.
        """
        # A lone PIL image is not iterable; treat it as a batch of one.
        if isinstance(images, Image.Image):
            images = [images]

        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        """Channels-first shape ``[3, image_size, image_size]`` of one processed image."""
        return [3, self.image_size, self.image_size]
|
||
|
|
||
|
|
||
|
# Register the processor with AutoImageProcessor for this config class.
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)


if __name__ == "__main__":
    # Smoke test: build a 1024px processor using Inception-style (0.5) statistics.
    demo_kwargs = dict(
        image_size=1024,
        image_mean=IMAGENET_INCEPTION_MEAN,
        image_std=IMAGENET_INCEPTION_STD,
        do_normalize=True,
    )
    image_processor = VLMImageProcessor(**demo_kwargs)
|