mirror of
https://github.com/deepseek-ai/DeepSeek-Coder
synced 2025-01-23 10:57:05 +00:00
62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
|
import os
|
||
|
import numpy as np
|
||
|
import json
|
||
|
|
||
|
class HumanEvalDataset:
|
||
|
|
||
|
def __init__(self, root, sample_num=1, language="python", issft=False):
|
||
|
"""
|
||
|
root: the path to the HumanEval dataset
|
||
|
sample_num: the number of samples for each prompt
|
||
|
language: the language of the HumanEval dataset
|
||
|
issft: whether to use the SFT setting
|
||
|
"""
|
||
|
self.root = root
|
||
|
self.data = open(os.path.join(self.root, f"humaneval-{language}.jsonl")).readlines()
|
||
|
|
||
|
tmp = self.get_qa_only_data(self.data, issft)
|
||
|
self.clean_data = []
|
||
|
for i in range(len(tmp)):
|
||
|
for j in range(sample_num):
|
||
|
self.clean_data.append(tmp[i])
|
||
|
self.stopwords = self.clean_data[0]["stopwords"]
|
||
|
np.random.seed(1234)
|
||
|
print(f"Read HumanEval from {root}, number of samples {len(self.clean_data)}")
|
||
|
|
||
|
def get_qa_only_data(self, data_json, sft=False):
|
||
|
"""
|
||
|
data_json: the jsonl file of HumanEval
|
||
|
sft: whether to use the SFT setting
|
||
|
return: a list of dict, each dict contains the prompt, task_id and stopwords
|
||
|
"""
|
||
|
ans = []
|
||
|
for line in data_json:
|
||
|
line = json.loads(line)
|
||
|
prompt = line["prompt"].strip()
|
||
|
if "prefix" in line:
|
||
|
origin_prompt = line["prefix"]
|
||
|
else:
|
||
|
origin_prompt = line["prompt"]
|
||
|
|
||
|
if sft:
|
||
|
prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context.\nWrite a response that appropriately completes the request.\n\n### Instruction:\nWrite a program to perform the given task.\n\nInput:\n{prompt}\n\n### Response:\n"""
|
||
|
if "stop_tokens" in line:
|
||
|
s = line["stop_tokens"]
|
||
|
else:
|
||
|
s = []
|
||
|
ans.append({"prompt":prompt, "task_id":line["task_id"], "original_prompt": origin_prompt, "stopwords":s})
|
||
|
return ans
|
||
|
|
||
|
def __len__(self):
|
||
|
"""
|
||
|
return the number of samples in the dataset
|
||
|
"""
|
||
|
return len(self.clean_data)
|
||
|
|
||
|
def __getitem__(self, index):
|
||
|
"""
|
||
|
return the sample at index
|
||
|
"""
|
||
|
sample = self.clean_data[index]
|
||
|
return sample
|