Merge pull request #19 from learningpro/main

增加FastAPI Demo
This commit is contained in:
Chong Ruan 2024-10-31 11:26:17 +08:00 committed by GitHub
commit e922b2b4a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 272 additions and 0 deletions

View File

@ -264,6 +264,22 @@ python demo/app.py
Have Fun!
### FastAPI Demo
It's easy to run a FastAPI server to host an API server running the same functions as gradio.
To start FastAPI server, run the following command:
```
python demo/fastapi_app.py
```
To test the server, you can open another terminal and run:
```
python demo/fastapi_client.py
```
## 5. License
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).

178
demo/fastapi_app.py Normal file
View File

@ -0,0 +1,178 @@
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import io
app = FastAPI()
# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@torch.inference_mode()
def multimodal_understanding(image_data, question, seed, top_p, temperature):
torch.cuda.empty_cache()
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
conversation = [
{
"role": "User",
"content": f"<image_placeholder>\n{question}",
"images": [image_data],
},
{"role": "Assistant", "content": ""},
]
pil_images = [Image.open(io.BytesIO(image_data))]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False if temperature == 0 else True,
use_cache=True,
temperature=temperature,
top_p=top_p,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
return answer
@app.post("/understand_image_and_question/")
async def understand_image_and_question(
file: UploadFile = File(...),
question: str = Form(...),
seed: int = Form(42),
top_p: float = Form(0.95),
temperature: float = Form(0.1)
):
image_data = await file.read()
response = multimodal_understanding(image_data, question, seed, top_p, temperature)
return JSONResponse({"response": response})
def generate(input_ids,
width,
height,
temperature: float = 1,
parallel_size: int = 5,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
torch.cuda.empty_cache()
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
pkv = None
for i in range(image_token_num_per_image):
outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
pkv = outputs.past_key_values
hidden_states = outputs.last_hidden_state
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
inputs_embeds = img_embeds.unsqueeze(dim=1)
patches = vl_gpt.gen_vision_model.decode_code(
generated_tokens.to(dtype=torch.int),
shape=[parallel_size, 8, width // patch_size, height // patch_size]
)
return generated_tokens.to(dtype=torch.int), patches
def unpack(dec, width, height, parallel_size=5):
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
visual_img[:, :, :] = dec
return visual_img
@torch.inference_mode()
def generate_image(prompt, seed, guidance):
torch.cuda.empty_cache()
seed = seed if seed is not None else 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
width = 384
height = 384
parallel_size = 5
with torch.no_grad():
messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
conversations=messages,
sft_format=vl_chat_processor.sft_format,
system_prompt=''
)
text = text + vl_chat_processor.image_start_tag
input_ids = torch.LongTensor(tokenizer.encode(text))
_, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
images = unpack(patches, width // 16 * 16, height // 16 * 16)
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
@app.post("/generate_images/")
async def generate_images(
prompt: str = Form(...),
seed: int = Form(None),
guidance: float = Form(5.0),
):
try:
images = generate_image(prompt, seed, guidance)
def image_stream():
for img in images:
buf = io.BytesIO()
img.save(buf, format='PNG')
buf.seek(0)
yield buf.read()
return StreamingResponse(image_stream(), media_type="multipart/related")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

78
demo/fastapi_client.py Normal file
View File

@ -0,0 +1,78 @@
import requests
from PIL import Image
import io
# Endpoint URLs
understand_image_url = "http://localhost:8000/understand_image_and_question/"
generate_images_url = "http://localhost:8000/generate_images/"
# Use your image file path here
image_path = "images/equation.png"
# Function to call the image understanding endpoint
def understand_image_and_question(image_path, question, seed=42, top_p=0.95, temperature=0.1):
files = {'file': open(image_path, 'rb')}
data = {
'question': question,
'seed': seed,
'top_p': top_p,
'temperature': temperature
}
response = requests.post(understand_image_url, files=files, data=data)
response_data = response.json()
print("Image Understanding Response:", response_data['response'])
# Function to call the text-to-image generation endpoint
def generate_images(prompt, seed=None, guidance=5.0):
data = {
'prompt': prompt,
'seed': seed,
'guidance': guidance
}
response = requests.post(generate_images_url, data=data, stream=True)
if response.ok:
img_idx = 1
# We will create a new BytesIO for each image
buffers = {}
try:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
# Use a boundary detection to determine new image start
if img_idx not in buffers:
buffers[img_idx] = io.BytesIO()
buffers[img_idx].write(chunk)
# Attempt to open the image
try:
buffer = buffers[img_idx]
buffer.seek(0)
image = Image.open(buffer)
img_path = f"generated_image_{img_idx}.png"
image.save(img_path)
print(f"Saved: {img_path}")
# Prepare the next image buffer
buffer.close()
img_idx += 1
except Exception as e:
# Continue loading data into the current buffer
continue
except Exception as e:
print("Error processing image:", e)
else:
print("Failed to generate images.")
# Example usage
if __name__ == "__main__":
# Call the image understanding API
understand_image_and_question(image_path, "What is this image about?")
# Call the image generation API
generate_images("A beautiful sunset over a mountain range, digital art.")