Merge pull request #93 from pcystc/main

Fix the position of add_generation_prompt
Daya Guo 2024-01-11 01:07:27 +08:00 committed by GitHub
commit bda040aa34


@@ -23,7 +23,8 @@ def generate_one(example, lang, tokenizer, model):
     prompt = build_deepseekcoder_instruction(languge_settings[lang]['full_name'], example['prompt'])
     inputs = tokenizer.apply_chat_template(
         [{'role': 'user', 'content': prompt }],
-        return_tensors="pt"
+        return_tensors="pt",
+        add_generation_prompt=True
     ).to(model.device)
 
     stop_id = tokenizer.convert_tokens_to_ids("<|EOT|>")
@@ -39,7 +40,7 @@ def generate_one(example, lang, tokenizer, model):
         eos_token_id=stop_id
     )
 
-    output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True, add_generation_prompt=True)
+    output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
     example['output'] = output
 
     return extract_generation_code(example, lang_code=lang)
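
For reference, a minimal sketch of the corrected usage (not part of the PR; the checkpoint name is only an assumed example): `add_generation_prompt=True` is an argument of `apply_chat_template`, where it appends the assistant-turn prefix so the model starts a reply, while `tokenizer.decode` does not use that flag.

```python
# Minimal sketch, not part of the PR. The checkpoint name is an assumption;
# any instruct model that ships a chat template behaves the same way.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-coder-6.7b-instruct", trust_remote_code=True
)

messages = [{"role": "user", "content": "Write a function that reverses a string."}]

# Correct placement: apply_chat_template renders the conversation and, with
# add_generation_prompt=True, appends the assistant prefix for generation.
inputs = tokenizer.apply_chat_template(
    messages,
    return_tensors="pt",
    add_generation_prompt=True,
)

# Decoding back shows the rendered prompt ending with the assistant prefix;
# decode() itself takes no add_generation_prompt argument.
print(tokenizer.decode(inputs[0], skip_special_tokens=False))
```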