mirror of
https://github.com/deepseek-ai/Janus
synced 2024-12-28 14:52:12 +00:00
commit
e922b2b4a7
16
README.md
16
README.md
@ -264,6 +264,22 @@ python demo/app.py
|
|||||||
|
|
||||||
Have Fun!
|
Have Fun!
|
||||||
|
|
||||||
|
### FastAPI Demo
|
||||||
|
It's easy to run a FastAPI server to host an API server running the same functions as gradio.
|
||||||
|
|
||||||
|
To start FastAPI server, run the following command:
|
||||||
|
|
||||||
|
```
|
||||||
|
python demo/fastapi_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
To test the server, you can open another terminal and run:
|
||||||
|
|
||||||
|
```
|
||||||
|
python demo/fastapi_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## 5. License
|
## 5. License
|
||||||
|
|
||||||
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).
|
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).
|
||||||
|
178
demo/fastapi_app.py
Normal file
178
demo/fastapi_app.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
import torch
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import io
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Load model and processor
|
||||||
|
model_path = "deepseek-ai/Janus-1.3B"
|
||||||
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
|
language_config = config.language_config
|
||||||
|
language_config._attn_implementation = 'eager'
|
||||||
|
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
|
||||||
|
language_config=language_config,
|
||||||
|
trust_remote_code=True)
|
||||||
|
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
|
||||||
|
|
||||||
|
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
||||||
|
tokenizer = vl_chat_processor.tokenizer
|
||||||
|
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def multimodal_understanding(image_data, question, seed, top_p, temperature):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "User",
|
||||||
|
"content": f"<image_placeholder>\n{question}",
|
||||||
|
"images": [image_data],
|
||||||
|
},
|
||||||
|
{"role": "Assistant", "content": ""},
|
||||||
|
]
|
||||||
|
|
||||||
|
pil_images = [Image.open(io.BytesIO(image_data))]
|
||||||
|
prepare_inputs = vl_chat_processor(
|
||||||
|
conversations=conversation, images=pil_images, force_batchify=True
|
||||||
|
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
|
||||||
|
|
||||||
|
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
||||||
|
outputs = vl_gpt.language_model.generate(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=prepare_inputs.attention_mask,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
max_new_tokens=512,
|
||||||
|
do_sample=False if temperature == 0 else True,
|
||||||
|
use_cache=True,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/understand_image_and_question/")
|
||||||
|
async def understand_image_and_question(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
question: str = Form(...),
|
||||||
|
seed: int = Form(42),
|
||||||
|
top_p: float = Form(0.95),
|
||||||
|
temperature: float = Form(0.1)
|
||||||
|
):
|
||||||
|
image_data = await file.read()
|
||||||
|
response = multimodal_understanding(image_data, question, seed, top_p, temperature)
|
||||||
|
return JSONResponse({"response": response})
|
||||||
|
|
||||||
|
|
||||||
|
def generate(input_ids,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
temperature: float = 1,
|
||||||
|
parallel_size: int = 5,
|
||||||
|
cfg_weight: float = 5,
|
||||||
|
image_token_num_per_image: int = 576,
|
||||||
|
patch_size: int = 16):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
|
||||||
|
for i in range(parallel_size * 2):
|
||||||
|
tokens[i, :] = input_ids
|
||||||
|
if i % 2 != 0:
|
||||||
|
tokens[i, 1:-1] = vl_chat_processor.pad_id
|
||||||
|
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
||||||
|
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
|
||||||
|
|
||||||
|
pkv = None
|
||||||
|
for i in range(image_token_num_per_image):
|
||||||
|
outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
|
||||||
|
pkv = outputs.past_key_values
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
|
||||||
|
logit_cond = logits[0::2, :]
|
||||||
|
logit_uncond = logits[1::2, :]
|
||||||
|
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
|
||||||
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
||||||
|
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
||||||
|
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
|
||||||
|
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
||||||
|
patches = vl_gpt.gen_vision_model.decode_code(
|
||||||
|
generated_tokens.to(dtype=torch.int),
|
||||||
|
shape=[parallel_size, 8, width // patch_size, height // patch_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
return generated_tokens.to(dtype=torch.int), patches
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(dec, width, height, parallel_size=5):
|
||||||
|
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
||||||
|
|
||||||
|
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
|
||||||
|
visual_img[:, :, :] = dec
|
||||||
|
|
||||||
|
return visual_img
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate_image(prompt, seed, guidance):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
seed = seed if seed is not None else 12345
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
width = 384
|
||||||
|
height = 384
|
||||||
|
parallel_size = 5
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
|
||||||
|
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
||||||
|
conversations=messages,
|
||||||
|
sft_format=vl_chat_processor.sft_format,
|
||||||
|
system_prompt=''
|
||||||
|
)
|
||||||
|
text = text + vl_chat_processor.image_start_tag
|
||||||
|
input_ids = torch.LongTensor(tokenizer.encode(text))
|
||||||
|
_, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
|
||||||
|
images = unpack(patches, width // 16 * 16, height // 16 * 16)
|
||||||
|
|
||||||
|
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate_images/")
|
||||||
|
async def generate_images(
|
||||||
|
prompt: str = Form(...),
|
||||||
|
seed: int = Form(None),
|
||||||
|
guidance: float = Form(5.0),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
images = generate_image(prompt, seed, guidance)
|
||||||
|
def image_stream():
|
||||||
|
for img in images:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format='PNG')
|
||||||
|
buf.seek(0)
|
||||||
|
yield buf.read()
|
||||||
|
|
||||||
|
return StreamingResponse(image_stream(), media_type="multipart/related")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
78
demo/fastapi_client.py
Normal file
78
demo/fastapi_client.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
# Endpoint URLs
|
||||||
|
understand_image_url = "http://localhost:8000/understand_image_and_question/"
|
||||||
|
generate_images_url = "http://localhost:8000/generate_images/"
|
||||||
|
|
||||||
|
# Use your image file path here
|
||||||
|
image_path = "images/equation.png"
|
||||||
|
|
||||||
|
# Function to call the image understanding endpoint
|
||||||
|
def understand_image_and_question(image_path, question, seed=42, top_p=0.95, temperature=0.1):
|
||||||
|
files = {'file': open(image_path, 'rb')}
|
||||||
|
data = {
|
||||||
|
'question': question,
|
||||||
|
'seed': seed,
|
||||||
|
'top_p': top_p,
|
||||||
|
'temperature': temperature
|
||||||
|
}
|
||||||
|
response = requests.post(understand_image_url, files=files, data=data)
|
||||||
|
response_data = response.json()
|
||||||
|
print("Image Understanding Response:", response_data['response'])
|
||||||
|
|
||||||
|
|
||||||
|
# Function to call the text-to-image generation endpoint
|
||||||
|
def generate_images(prompt, seed=None, guidance=5.0):
|
||||||
|
data = {
|
||||||
|
'prompt': prompt,
|
||||||
|
'seed': seed,
|
||||||
|
'guidance': guidance
|
||||||
|
}
|
||||||
|
response = requests.post(generate_images_url, data=data, stream=True)
|
||||||
|
|
||||||
|
if response.ok:
|
||||||
|
img_idx = 1
|
||||||
|
|
||||||
|
# We will create a new BytesIO for each image
|
||||||
|
buffers = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
# Use a boundary detection to determine new image start
|
||||||
|
if img_idx not in buffers:
|
||||||
|
buffers[img_idx] = io.BytesIO()
|
||||||
|
|
||||||
|
buffers[img_idx].write(chunk)
|
||||||
|
|
||||||
|
# Attempt to open the image
|
||||||
|
try:
|
||||||
|
buffer = buffers[img_idx]
|
||||||
|
buffer.seek(0)
|
||||||
|
image = Image.open(buffer)
|
||||||
|
img_path = f"generated_image_{img_idx}.png"
|
||||||
|
image.save(img_path)
|
||||||
|
print(f"Saved: {img_path}")
|
||||||
|
|
||||||
|
# Prepare the next image buffer
|
||||||
|
buffer.close()
|
||||||
|
img_idx += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Continue loading data into the current buffer
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error processing image:", e)
|
||||||
|
else:
|
||||||
|
print("Failed to generate images.")
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Call the image understanding API
|
||||||
|
understand_image_and_question(image_path, "What is this image about?")
|
||||||
|
|
||||||
|
# Call the image generation API
|
||||||
|
generate_images("A beautiful sunset over a mountain range, digital art.")
|
Loading…
Reference in New Issue
Block a user