init project

This commit is contained in:
Dejian
2023-11-02 22:07:09 +08:00
commit a4ba628dfd
111 changed files with 26064 additions and 0 deletions

Binary file not shown.

View File

@@ -0,0 +1,48 @@
import os
import numpy as np
import json
class MBPPDataset:
    """In-memory view of ``mbpp.jsonl`` with few-shot prompts and a test split.

    After construction the instance exposes:

    * ``data`` - the raw lines read from the file,
    * ``clean_data`` - one dict per record (see ``get_qa_only_data``),
    * ``prompt`` - three cumulative few-shot prompts built from the records
      at list positions 1, 2 and 3 (``prompt[k]`` holds ``k + 1`` examples),
    * ``testdata`` - the records at list positions 10..509, each repeated
      ``samplenum`` times in a row.
    """

    def __init__(self, root, samplenum=1):
        """
        Load ``mbpp.jsonl`` from ``root`` and build prompts and test samples.

        root: root directory of the data files
        samplenum: how many consecutive copies of each test record are put
            into ``testdata``

        Raises IndexError if the file holds fewer than 510 records.
        """
        self.root = root
        # Use a context manager so the handle is closed deterministically, and
        # pin the encoding so reading does not depend on the platform locale.
        with open(os.path.join(root, "mbpp.jsonl"), encoding="utf-8") as f:
            self.data = f.readlines()
        self.clean_data = self.get_qa_only_data(self.data)

        self.prompt = []
        for i in range(1, 4):
            record = self.clean_data[i]
            prompt = record["prompt"]
            tests = "\n".join(record["test"])
            code = record["code"].replace("\r", "").replace("\t", " ")
            prompt1 = f"You are an expert Python programmer, and here is your task: {prompt} Your code should pass these tests:\n\n{tests}\n[BEGIN]\n{code}\n[DONE]\n"
            # Each entry extends the previous one, giving 1-, 2- and 3-shot prompts.
            previous = self.prompt[-1] if self.prompt else ""
            self.prompt.append(previous + prompt1)

        self.testdata = []
        for i in range(10, 510):
            # The same dict object is appended samplenum times (no copies).
            self.testdata.extend([self.clean_data[i]] * samplenum)

        # Seeds NumPy's global RNG as a side effect of constructing the dataset.
        np.random.seed(1234)
        print(f"Read MBPP from {root}, number of samples {len(self.testdata)}")

    def get_qa_only_data(self, data_json):
        """
        Parse JSONL lines and keep only the fields used for prompting/evaluation.

        data_json: iterable of strings, each holding one JSON object with the
            keys ``text``, ``test_list``, ``code`` and ``task_id``

        Returns a list of dicts with keys ``prompt``, ``test``, ``code`` and
        ``task_id``. Blank or whitespace-only lines are skipped.
        """
        ans = []
        for line in data_json:
            # A stray blank line would otherwise make json.loads raise.
            if not line.strip():
                continue
            record = json.loads(line)
            ans.append(
                {
                    "prompt": record["text"],
                    "test": record["test_list"],
                    "code": record["code"],
                    "task_id": record["task_id"],
                }
            )
        return ans

    def __len__(self):
        """Return the number of test samples (including repeats)."""
        return len(self.testdata)

    def __getitem__(self, index):
        """Return the test sample stored at ``index`` in ``testdata``."""
        return self.testdata[index]

View File

@@ -0,0 +1,40 @@
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """
    Cleans up the generated code.

    code: the generated text to clean
    language_type: language name, compared case-insensitively; "python" and
        "ts" get special handling, anything else (including None) is only
        truncated at the caller's stop words
    dataset: accepted for interface compatibility; not used here
    issft: for Python only, first extract the body of a ```python fence
    stop_words: extra stop sequences for non-Python languages; None means
        no extra stop words

    Returns the code cut just before the earliest stop sequence found.
    """
    # None sentinel instead of a shared mutable default; copy so the caller's
    # sequence is never aliased and list concatenation below always works.
    extra_stop_words = list(stop_words) if stop_words is not None else []
    # The declared default is None, so guard before calling .lower().
    language = language_type.lower() if language_type is not None else ""

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # Python always uses this fixed list; caller-supplied stop_words are
        # not consulted on this branch.
        python_stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        return _truncate_code_at_stopwords(code, python_stop_words)

    if language == "ts":
        ts_stop_words = extra_stop_words + [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]
        return _truncate_code_at_stopwords(code, ts_stop_words)

    return _truncate_code_at_stopwords(code, extra_stop_words)
def _clean_python_code_for_sft(code):
code = code.replace("\r", "")
if "```python" in code:
code_start_idx = code.index("```python")
code = code[code_start_idx:].replace("```python", "").strip()
end_idx = code.find("```") if "```" in code else len(code)
code = code[:end_idx].strip()
return code
def _truncate_code_at_stopwords(code, stop_words):
min_stop_idx = len(code)
for stop_word in stop_words:
stop_index = code.find(stop_word)
if 0 <= stop_index < min_stop_idx:
min_stop_idx = stop_index
return code[:min_stop_idx]