DeepSeek-Math/predict.py

83 lines
2.9 KiB
Python
Raw Normal View History

2024-02-11 14:45:44 +00:00
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import time
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers.generation.streamers import TextIteratorStreamer
from cog import BasePredictor, Input, ConcatenateIterator
# Enable faster download speed
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
CACHE_DIR = "model_cache"
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
model_name = "deepseek-ai/deepseek-math-7b-base"
self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
cache_dir=CACHE_DIR,
)
self.model.generation_config = GenerationConfig.from_pretrained(
model_name, cache_dir=CACHE_DIR
)
self.model.generation_config.pad_token_id = (
self.model.generation_config.eos_token_id
)
def predict(
self,
text: str = Input(
description="Input text.",
default="The integral of x^2 from 0 to 2 is",
),
max_new_tokens: int = Input(
description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
default=100,
),
temperature: float = Input(
description="The value used to modulate the next token probabilities.",
default=1,
),
top_k: int = Input(
description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
default=50,
),
top_p: float = Input(
description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
default=0.9,
),
) -> ConcatenateIterator[str]:
"""Run a single prediction on the model"""
inputs = self.tokenizer(text, return_tensors="pt")
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
with torch.inference_mode():
thread = Thread(
target=self.model.generate,
kwargs=dict(
**inputs.to(self.model.device),
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
streamer=streamer,
use_cache=True
),
)
thread.start()
for new_token in streamer:
yield new_token
thread.join()