Fix MBPP evaluation to read the test split (mbpp_test.jsonl) and lower max_new_tokens from 1024 to 512

This commit is contained in:
Yang Dejian 2023-11-23 16:46:17 +08:00
parent d40caedfdf
commit 94c476e295

View File

@@ -71,10 +71,9 @@ def generate_one(example, tokenizer, model):
stop_id = tokenizer.convert_tokens_to_ids("<|EOT|>") stop_id = tokenizer.convert_tokens_to_ids("<|EOT|>")
assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found" assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found"
outputs = model.generate( outputs = model.generate(
inputs, inputs,
max_new_tokens=1024, max_new_tokens=512,
do_sample=False, do_sample=False,
# top_p=0.95, # top_p=0.95,
# temperature=temperature, # temperature=temperature,
@@ -83,8 +82,8 @@ def generate_one(example, tokenizer, model):
) )
output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
# print(output)
example['gpt_completion'] = output example['gpt_completion'] = output
return convert_for_evaluation(example) return convert_for_evaluation(example)
def generate_main(args): def generate_main(args):
@@ -122,7 +121,7 @@ def generate_main(args):
result = evaluate_functional_correctness( result = evaluate_functional_correctness(
input_file=saved_path, input_file=saved_path,
tmp_dir=temp_dir, tmp_dir=temp_dir,
problem_file=problem_file, problem_file=os.path.join(data_abs_dir, f"mbpp_test.jsonl"),
language='python', language='python',
is_mbpp=True is_mbpp=True
) )