# Mirror of https://github.com/deepseek-ai/DeepSeek-VL
# Synced 2024-11-22 11:27:51 +00:00 (243 lines, 8.0 KiB, Python).
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
|
|
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms
|
|
from einops import rearrange
|
|
|
|
from deepseek_vl.models.sam import create_sam_vit
|
|
from deepseek_vl.models.siglip_vit import create_siglip_vit
|
|
|
|
|
|
class CLIPVisionTower(nn.Module):
    """Wrap a single vision backbone chosen by the prefix of ``model_name``.

    * ``"siglip..."`` -> ``create_siglip_vit`` (``select_feature`` is then forced
      to ``"same"``),
    * ``"sam..."``    -> ``create_sam_vit``,
    * anything else   -> Hugging Face ``transformers.CLIPVisionModel``.

    If both ``pixel_mean`` and ``pixel_std`` are given, inputs are normalized with
    ``torchvision.transforms.Normalize`` before the backbone runs; the backbone
    output is then post-processed by ``feature_select``.
    """

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[list] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Args:
            model_name: backbone identifier; its prefix selects the builder.
            image_size: input resolution forwarded to the SigLIP / SAM builders.
            select_feature: post-processing mode used by ``feature_select``
                (``"patch"``, ``"cls_patch"`` or ``"same"``).
            select_layer: index into ``hidden_states`` for Hugging Face outputs;
                also forwarded to the SigLIP / SAM builders.
            select_layers: stored on the instance; not otherwise used here.
            ckpt_path: checkpoint location forwarded to the SigLIP / SAM builders;
                for the Hugging Face path it replaces ``model_name`` when non-empty.
            pixel_mean: per-channel mean for input normalization.
            pixel_std: per-channel std for input normalization.
            **kwargs: extra options forwarded to the backbone builder.
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        # May override self.select_feature (SigLIP path), so it runs after the
        # attributes above are set.
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        """Instantiate the backbone described by ``vision_tower_params``.

        Returns:
            ``(vision_tower, forward_kwargs)``, where ``forward_kwargs`` are the
            extra keyword arguments passed on every ``vision_tower(...)`` call.
        """
        if self.model_name.startswith("siglip"):
            # Force pass-through in feature_select for SigLIP (presumably because
            # its output has no CLS token to strip -- TODO confirm).
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel

            # `from_pretrained` takes the model id / local path as its first
            # positional argument; the wrapper-only keys must not be forwarded as
            # loader kwargs. `select_layer` is applied in feature_select instead,
            # and `image_size` is dropped because the pretrained config defines
            # the resolution.
            hf_params = dict(vision_tower_params)
            model_name = hf_params.pop("model_name")
            ckpt_path = hf_params.pop("ckpt_path", "")
            hf_params.pop("image_size", None)
            hf_params.pop("select_layer", None)

            vision_tower = CLIPVisionModel.from_pretrained(
                ckpt_path or model_name, **hf_params
            )
            # Needed so feature_select can index `hidden_states`.
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the features to return from the raw backbone output.

        A plain tensor is used as-is; otherwise ``hidden_states[select_layer]`` is
        taken. Then ``"patch"`` drops the first token along dim 1, while
        ``"cls_patch"`` and ``"same"`` keep everything.

        Raises:
            ValueError: if ``select_feature`` is not a known mode.
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # the backbone already returned the self.select_layer's features
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # drop the leading (presumably CLS) token
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
|
|
|
|
|
|
class HybridVisionTower(nn.Module):
    """Run a high-resolution and a low-resolution ``CLIPVisionTower`` on one image.

    The high-res tower receives the input unchanged and its ``[b, c, h, w]`` output
    is flattened to ``[b, h*w, c]``; the low-res tower receives the input resized
    by ``torchvision.transforms.Resize(low_res_cfg["image_size"])``. The two
    results are combined according to ``concat_type``.
    """

    # Accepted values of ``concat_type``.
    _CONCAT_TYPES = ("feature", "sequence", "add", "tuple")

    def __init__(
        self,
        high_res_cfg: Dict,
        low_res_cfg: Dict,
        freeze_high: bool = False,
        freeze_low: bool = False,
        concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
        **ignore_kwargs,
    ):
        """
        Args:
            high_res_cfg: kwargs for the high-resolution ``CLIPVisionTower``.
            low_res_cfg: kwargs for the low-resolution ``CLIPVisionTower``; must
                contain ``"image_size"``, which is also the resize target.
            freeze_high: if True, freeze the whole high-res tower and put it in
                eval mode; if False, only its parameters whose names contain
                ``"downsamples"`` or ``"neck"`` stay trainable.
            freeze_low: if True, freeze the low-res tower and put it in eval mode.
            concat_type: how ``forward`` combines the two outputs.
            **ignore_kwargs: accepted and ignored.

        Raises:
            ValueError: if ``concat_type`` is not one of the supported values.
        """
        super().__init__()

        # Fail fast, before building two backbones, instead of at the first
        # forward() call.
        if concat_type not in self._CONCAT_TYPES:
            raise ValueError(
                "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
            )

        self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
        self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
        self.low_res_size = low_res_cfg["image_size"]
        self.concat_type = concat_type

        # Not used by forward(); kept so the module's parameter set / state_dict
        # keys stay unchanged (NOTE(review): presumably for checkpoint compatibility).
        self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
        self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))

        if freeze_high:
            self.vision_tower_high.requires_grad_(False)
            self.vision_tower_high = self.vision_tower_high.eval()
        else:
            # train only the downsamples and neck of the high-res tower
            for name, param in self.vision_tower_high.named_parameters():
                param.requires_grad = "downsamples" in name or "neck" in name

        if freeze_low:
            self.vision_tower_low.requires_grad_(False)
            self.vision_tower_low = self.vision_tower_low.eval()

        # With an int size, Resize matches the shorter edge and keeps the aspect
        # ratio; with an (h, w) pair it resizes to exactly that shape.
        self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)

    def forward(self, images: torch.Tensor):
        """
        Args:
            images (torch.Tensor): [bs, 3, H, W]

        Returns:
            Depending on ``concat_type``:
              * ``"feature"``:  high/low features concatenated on the last dim,
              * ``"sequence"``: high/low features concatenated on dim 1 (tokens),
              * ``"add"``:      element-wise sum (shapes must be compatible),
              * ``"tuple"``:    ``(high_res, low_res)``.
        """
        # [bs, c, h, w]
        high_images = images
        # [bs, c, h_low, w_low]
        low_images = self.resize(images)

        # the two towers run independently
        high_res = self.vision_tower_high(high_images)
        # [bs, c, h, w] -> [bs, h*w, c]
        high_res = rearrange(high_res, "b c h w -> b (h w) c")
        low_res = self.vision_tower_low(low_images)

        if self.concat_type == "feature":
            images_features = torch.cat([high_res, low_res], dim=-1)
        elif self.concat_type == "sequence":
            images_features = torch.cat([high_res, low_res], dim=1)
        elif self.concat_type == "add":
            images_features = high_res + low_res
        elif self.concat_type == "tuple":
            images_features = (high_res, low_res)
        else:
            # Still reachable if concat_type is reassigned after construction.
            raise ValueError(
                "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
            )

        return images_features
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a SAM ("sam_b_downsample", 1024px) + SigLIP
    # ("siglip_large_patch16_384", 384px) hybrid tower, run a zero batch through
    # it and print the output shapes. Uses CUDA when available, else the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    image_size = 1024
    x = torch.zeros(2, 3, image_size, image_size).bfloat16().to(device)

    high_res_cfg = dict(
        model_name="sam_b_downsample",
        select_feature="same",
        image_size=image_size,
        pixel_mean=(0.48145466, 0.4578275, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711),
        select_layer=-1,
        ckpt_path="",
    )

    low_res_cfg = dict(
        model_name="siglip_large_patch16_384",
        select_feature="same",
        image_size=384,
        pixel_mean=(0.5, 0.5, 0.5),
        pixel_std=(0.5, 0.5, 0.5),
        select_layer=-1,
        ckpt_path="",
    )

    net = (
        HybridVisionTower(
            high_res_cfg=high_res_cfg,
            low_res_cfg=low_res_cfg,
            freeze_high=True,
            freeze_low=True,
            concat_type="tuple",
        )
        .bfloat16()
        .to(device)
    )

    # Only shapes are inspected, so no autograd graph is needed.
    with torch.no_grad():
        high_x, low_x = net(x)
    print(x.shape, high_x.shape, low_x.shape)
|