Add output streaming

Luis 2024-03-11 19:52:55 +00:00
parent 53c540ec9a
commit 86b6e851f8


@@ -1,12 +1,12 @@
 # Prediction interface for Cog ⚙️
 # https://github.com/replicate/cog/blob/main/docs/python.md
-from cog import BasePredictor, Input, Path
+from cog import BasePredictor, Input, Path, ConcatenateIterator
 import os
 import torch
+from threading import Thread
-from transformers import AutoModelForCausalLM
 from deepseek_vl.utils.io import load_pil_images
+from transformers import AutoModelForCausalLM, TextIteratorStreamer
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 # Enable faster download speed
@@ -34,9 +34,9 @@ class Predictor(BasePredictor):
     def predict(
         self,
         image: Path = Input(description="Input image"),
-        prompt: str = Input(description="Input prompt", default="Describe the image"),
+        prompt: str = Input(description="Input prompt", default="Describe this image"),
         max_new_tokens: int = Input(description="Maximum number of tokens to generate", default=512)
-    ) -> str:
+    ) -> ConcatenateIterator[str]:
         """Run a single prediction on the model"""
         conversation = [
             {
@@ -57,21 +57,26 @@ class Predictor(BasePredictor):
             images=pil_images,
             force_batchify=True
         ).to('cuda')
-        # run image encoder to get the image embeddings
-        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-        # run the model to get the response
-        outputs = self.vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs.attention_mask,
-            pad_token_id=self.tokenizer.eos_token_id,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            use_cache=True
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
-        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-        return answer
+        thread = Thread(
+            target=self.vl_gpt.language_model.generate,
+            kwargs={
+                "inputs_embeds": self.vl_gpt.prepare_inputs_embeds(**prepare_inputs),
+                "attention_mask": prepare_inputs.attention_mask,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "bos_token_id": self.tokenizer.bos_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": False,
+                "use_cache": True,
+                "streamer": streamer,
+            },
+        )
+        thread.start()
+        for new_token in streamer:
+            yield new_token
+        thread.join()
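
For reference, the thread-plus-streamer pattern this commit adopts can be sketched in isolation. The sketch below is illustrative rather than part of the commit: it assumes a hypothetical text-only Hugging Face model ("gpt2") in place of DeepSeek-VL and prints chunks instead of yielding them through Cog.

# Minimal sketch of streaming generation with TextIteratorStreamer.
# Assumption: a plain text model stands in for the multimodal pipeline above.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# The streamer decodes tokens as generate() emits them; skip_prompt=True drops
# the echoed input and skip_special_tokens=True drops markers such as EOS.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until the whole sequence is done, so it runs on a worker
# thread while the main thread drains the streamer as chunks arrive.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 64, "do_sample": False, "streamer": streamer},
)
thread.start()
for text in streamer:
    print(text, end="", flush=True)  # a Cog predictor would `yield text` instead
thread.join()

Two details worth noting about the diff itself: the return annotation changes to ConcatenateIterator[str], which is how Cog marks a prediction whose output is a stream of string chunks for clients to concatenate, and the final thread.join() ensures the generate() call has fully finished before the predictor returns.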