mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2024-11-22 11:38:17 +00:00
264 lines
12 KiB
Python
Executable File
264 lines
12 KiB
Python
Executable File
import torch
|
|
import tqdm
|
|
from transformers import StoppingCriteria, GenerationConfig
|
|
|
|
class KeyWordsCriteria(StoppingCriteria):
|
|
def __init__(self, stop_id_sequences, tokenizer, prompt_length):
|
|
assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
|
|
self.tokenizer = tokenizer
|
|
self.stop_id_sequences = stop_id_sequences
|
|
self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences]
|
|
print(f"stop sequences: {self.stop_sequences}", flush=True)
|
|
self.prompt_length = prompt_length
|
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
sequences_should_be_stopped = []
|
|
for i in range(input_ids.shape[0]):
|
|
ids = input_ids[i][self.prompt_length:].tolist()
|
|
should_be_stopped = False
|
|
for stop_ids, stop_sequence in zip(self.stop_id_sequences, self.stop_sequences):
|
|
_ids = ids
|
|
for j in range(len(_ids), 0, -1):
|
|
s = self.tokenizer.decode(_ids[max(j - len(stop_ids) - 3, 0) :j])
|
|
if s.endswith(stop_sequence):
|
|
should_be_stopped = True
|
|
break
|
|
if should_be_stopped:
|
|
break
|
|
sequences_should_be_stopped.append(should_be_stopped)
|
|
return all(sequences_should_be_stopped)
|
|
|
|
@torch.no_grad()
|
|
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs):
|
|
generations = []
|
|
finish_completion = []
|
|
if not disable_tqdm:
|
|
progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
|
|
|
|
if stop_id_sequences is not None:
|
|
stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences]
|
|
|
|
if end_of_generation_id_sequence is not None:
|
|
end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence)
|
|
|
|
num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
|
|
generation_kwargs['use_cache'] = True
|
|
for i in range(0, len(prompts), batch_size):
|
|
batch_prompts = prompts[i:i+batch_size]
|
|
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__))
|
|
batch_input_ids = tokenized_prompts.input_ids
|
|
attention_mask = tokenized_prompts.attention_mask
|
|
|
|
if model.device.type == "cuda":
|
|
batch_input_ids = batch_input_ids.cuda()
|
|
attention_mask = attention_mask.cuda()
|
|
|
|
batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences
|
|
try:
|
|
batch_outputs = model.generate(
|
|
input_ids=batch_input_ids,
|
|
attention_mask=attention_mask,
|
|
stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None,
|
|
**generation_kwargs
|
|
)
|
|
|
|
# the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
|
|
# so some outputs still have the stop sequence, which we need to remove.
|
|
if stop_id_sequences:
|
|
for output_idx in range(batch_outputs.shape[0]):
|
|
finish = False
|
|
for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
|
|
if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences):
|
|
if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence):
|
|
batch_finish_completion[output_idx] = True
|
|
batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
|
|
break
|
|
|
|
# remove the prompt from the output
|
|
# we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
|
|
# we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token.
|
|
# space is important for some tasks (e.g., code completion).
|
|
batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
|
|
batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
|
|
# duplicate the prompts to match the number of return sequences
|
|
batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
|
|
batch_generations = [
|
|
output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
|
|
]
|
|
except Exception as e:
|
|
print("Error when generating completions for batch:")
|
|
print(batch_prompts)
|
|
print("Error message:")
|
|
print(e)
|
|
print("Use empty string as the completion.")
|
|
batch_generations = [""] * len(batch_prompts) * num_return_sequences
|
|
|
|
generations += batch_generations
|
|
finish_completion += batch_finish_completion
|
|
|
|
if not disable_tqdm:
|
|
progress.update(len(batch_prompts)//num_return_sequences)
|
|
|
|
assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
|
|
return generations, finish_completion
|
|
|
|
|
|
@torch.no_grad()
|
|
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
|
|
predictions, probs = [], []
|
|
if not disable_tqdm:
|
|
progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
|
|
|
|
for i in range(0, len(prompts), batch_size):
|
|
batch_prompts = prompts[i: i+batch_size]
|
|
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
|
|
batch_input_ids = tokenized_prompts.input_ids
|
|
attention_mask = tokenized_prompts.attention_mask
|
|
|
|
if model.device.type == "cuda":
|
|
batch_input_ids = batch_input_ids.cuda()
|
|
attention_mask = attention_mask.cuda()
|
|
|
|
batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
|
|
if candidate_token_ids is not None:
|
|
batch_logits = batch_logits[:, candidate_token_ids]
|
|
batch_probs = torch.softmax(batch_logits, dim=-1)
|
|
batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
|
|
if return_token_predictions:
|
|
if candidate_token_ids is not None:
|
|
candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
|
|
batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices]
|
|
else:
|
|
batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices)
|
|
predictions += batch_predictions
|
|
else:
|
|
predictions += batch_prediction_indices.tolist()
|
|
probs += batch_probs.tolist()
|
|
|
|
if not disable_tqdm:
|
|
progress.update(len(batch_prompts))
|
|
|
|
assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
|
|
return predictions, probs
|
|
|
|
|
|
@torch.no_grad()
|
|
def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False):
|
|
'''
|
|
Each scoring example is a dict, which contains the following keys:
|
|
- prompt: the prompt to score
|
|
- completions: a list of completions to score
|
|
'''
|
|
|
|
if not disable_tqdm:
|
|
progress = tqdm.tqdm(total=len(scoring_examples), desc="Scoring Completions")
|
|
|
|
# unroll the scoring examples
|
|
unrolled_examples = []
|
|
for scoring_example in scoring_examples:
|
|
prompt = scoring_example["prompt"]
|
|
for completion in scoring_example["completions"]:
|
|
unrolled_examples.append({
|
|
"prompt": prompt,
|
|
"completion": completion
|
|
})
|
|
|
|
scores = []
|
|
# currently we don't support batching, because we want to directly use the loss returned by the model to score each completion.
|
|
for unrolled_example in unrolled_examples:
|
|
encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None)
|
|
# unsqueeze the batch dimension
|
|
for key, value in encoded_example.items():
|
|
encoded_example[key] = value.unsqueeze(0)
|
|
if model.device.type == "cuda":
|
|
encoded_example = {
|
|
key: value.cuda() for key, value in encoded_example.items()
|
|
}
|
|
outputs = model(**encoded_example)
|
|
loss = outputs.loss
|
|
scores.append(-loss.item())
|
|
if not disable_tqdm:
|
|
progress.update(1)
|
|
|
|
# roll up the scores
|
|
rolled_up_scores = {}
|
|
for unrolled_example, score in zip(unrolled_examples, scores):
|
|
prompt = unrolled_example["prompt"]
|
|
completion = unrolled_example["completion"]
|
|
if prompt not in rolled_up_scores:
|
|
rolled_up_scores[prompt] = {}
|
|
rolled_up_scores[prompt][completion] = score
|
|
|
|
return rolled_up_scores
|
|
|
|
|
|
|
|
def load_hf_lm_and_tokenizer(
|
|
model_name_or_path,
|
|
tokenizer_name_or_path=None,
|
|
device_map="auto",
|
|
load_in_8bit=False,
|
|
load_in_half=False,
|
|
gptq_model=False,
|
|
use_fast_tokenizer=True,
|
|
padding_side="left",
|
|
):
|
|
|
|
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
|
|
|
|
if not tokenizer_name_or_path:
|
|
tokenizer_name_or_path = model_name_or_path
|
|
|
|
is_chatglm2 = 'chatglm2' in tokenizer_name_or_path.lower() or 'chatglm2' in model_name_or_path
|
|
is_qwen = 'qwen' in tokenizer_name_or_path.lower() or 'qwen' in model_name_or_path
|
|
|
|
if is_chatglm2 or is_qwen:
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
|
|
if is_qwen:
|
|
tokenizer.eos_token = '<|endoftext|>'
|
|
tokenizer.eos_token_id = 151643
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
else:
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
|
|
# set padding side to left for batch generation
|
|
tokenizer.padding_side = padding_side
|
|
# set pad token to eos token if pad token is not set (as is the case for llama models)
|
|
if tokenizer.pad_token is None:
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
|
if gptq_model:
|
|
from auto_gptq import AutoGPTQForCausalLM
|
|
model_wrapper = AutoGPTQForCausalLM.from_quantized(
|
|
model_name_or_path, device="cuda:0", use_triton=True
|
|
)
|
|
model = model_wrapper.model
|
|
elif load_in_8bit:
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_name_or_path,
|
|
device_map=device_map,
|
|
load_in_8bit=True
|
|
)
|
|
else:
|
|
kwargs = {}
|
|
model_class = AutoModelForCausalLM
|
|
if is_chatglm2:
|
|
kwargs = {'trust_remote_code': True}
|
|
model_class = AutoModel
|
|
elif is_qwen:
|
|
kwargs = {'trust_remote_code': True}
|
|
if device_map:
|
|
model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
|
|
else:
|
|
model = model_class.from_pretrained(model_name_or_path, **kwargs)
|
|
if torch.cuda.is_available():
|
|
model = model.cuda()
|
|
if is_qwen:
|
|
model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
|
model.generation_config.do_sample = False
|
|
if not is_chatglm2 and not is_qwen and load_in_half:
|
|
model = model.half()
|
|
model.eval()
|
|
return model, tokenizer
|