mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2024-11-21 19:17:41 +00:00
83 lines
2.9 KiB
Python
83 lines
2.9 KiB
Python
# Prediction interface for Cog ⚙️
|
|
# https://github.com/replicate/cog/blob/main/docs/python.md
|
|
|
|
import os
|
|
import time
|
|
from threading import Thread
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
|
from transformers.generation.streamers import TextIteratorStreamer
|
|
from cog import BasePredictor, Input, ConcatenateIterator
|
|
|
|
# Opt in to the accelerated Hugging Face Hub download path.
# NOTE(review): this runs after `transformers` has already been imported;
# huggingface_hub may read the flag only once at import time — confirm that
# setting it here still takes effect.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

# Directory handed to every `from_pretrained` call as `cache_dir`.
CACHE_DIR = "model_cache"
|
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that streams sampled completions from DeepSeek-Math 7B base."""

    def setup(self) -> None:
        """Load the tokenizer and model once so later predictions can reuse them.

        The weights are loaded in bfloat16 with ``device_map="auto"`` and all
        downloads are cached under ``CACHE_DIR``.
        """
        model_name = "deepseek-ai/deepseek-math-7b-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=CACHE_DIR,
        )

        generation_config = GenerationConfig.from_pretrained(
            model_name, cache_dir=CACHE_DIR
        )
        # Reuse the end-of-sequence id as the padding id for generation.
        generation_config.pad_token_id = generation_config.eos_token_id
        self.model.generation_config = generation_config

    def predict(
        self,
        text: str = Input(
            description="Input text.",
            default="The integral of x^2 from 0 to 2 is",
        ),
        max_new_tokens: int = Input(
            description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
            default=100,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.",
            default=1,
        ),
        top_k: int = Input(
            description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
            default=50,
        ),
        top_p: float = Input(
            description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
            default=0.9,
        ),
    ) -> ConcatenateIterator[str]:
        """Stream a sampled continuation of ``text`` piece by piece.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``; this generator yields each decoded chunk
        (prompt and special tokens skipped) as soon as it is available.

        Raises:
            Exception: any error raised by ``model.generate`` in the worker
                thread is re-raised here once the stream has been drained.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs.to(self.model.device),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
        )
        worker_errors = []

        def _generate() -> None:
            """Run generation in the worker thread and never leave the stream open."""
            try:
                # inference_mode is thread-local, so it has to be entered in
                # the thread that actually executes generate(). Entering it in
                # the generator instead would not cover this thread and would
                # stay active in the consumer's thread between yields.
                with torch.inference_mode():
                    self.model.generate(**generate_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # generate() only closes the stream on success; close it here
                # so the consuming loop below cannot block forever.
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()
        for new_token in streamer:
            yield new_token
        thread.join()

        # Surface a worker failure to the caller instead of silently
        # returning a truncated completion.
        if worker_errors:
            raise worker_errors[0]
|