# Mirror of https://github.com/deepseek-ai/DeepSeek-Coder
# (synced 2025-01-22 18:45:34 +00:00; 146 lines, 4.3 KiB, Python)
import re
|
|
|
|
# Per-language settings consumed by extract_generation_code, keyed by the
# short language code callers pass in. The (misspelled) name is left as-is
# because other modules may import it under this name.
#   full_name - tag expected right after ``` in the model output (the match is
#               case-insensitive); also handed to get_function_name.
#   indent    - fallback indentation of the closing "}" line, used when the
#               target function cannot be located in the extracted block.
#   main      - optional marker; the extracted block is truncated at its
#               first occurrence so the entry point is discarded.
languge_settings = dict(
    python=dict(full_name='Python', indent=4),
    cpp=dict(full_name='cpp', indent=0, main="int main()"),
    java=dict(full_name='Java', indent=4, main="public static void main"),
    cs=dict(full_name="csharp", indent=0, main="public static void Main"),
    php=dict(full_name="PHP", indent=0),
    ts=dict(full_name="TypeScript", indent=0),
    js=dict(full_name="JavaScript", indent=0),
    sh=dict(full_name="Bash", indent=0),
)
|
|
|
|
def get_function_name(question: str, lang: str):
    """Locate the target function signature inside a prompt.

    Blank lines are discarded first. For Python (``lang`` compared
    case-insensitively) the anchor is the LAST line starting with ``"def "``,
    and the returned name is that line up to its first ``(`` -- so it still
    carries the ``def`` keyword (e.g. ``"def foo"``). For any other language
    the anchor is the last non-blank line, cut at its first ``{``.

    Returns:
        (func_name, func_prefix): the signature text described above, and
        the non-blank lines that precede the anchor, joined with newlines.

    Raises:
        IndexError: if the prompt is blank, or (Python) contains no line
            starting with ``"def "``.
    """
    lines = [line for line in question.strip().split('\n') if line.strip()]

    if lang.lower() != 'python':
        # Brace languages: the prompt ends with the signature line.
        return lines[-1].split('{')[0].strip(), "\n".join(lines[:-1])

    def_positions = [pos for pos, line in enumerate(lines) if line.startswith("def ")]
    anchor = def_positions[-1]
    signature = lines[anchor].split('(')[0].strip()
    return signature, "\n".join(lines[:anchor])
|
|
|
|
def _count_spaces_before(text: str, pos: int) -> int:
    """Return how many consecutive ' ' characters immediately precede text[pos]."""
    count = 0
    # Stop at the start of the string: text[-1] would wrap around to the last
    # character and could count a trailing space as indentation.
    while pos - count > 0 and text[pos - count - 1] == ' ':
        count += 1
    return count


def _find_closing_line(code_block: str, indent: int) -> int:
    """Return the index of the last newline + `indent` spaces + '}' in
    code_block, or len(code_block) when there is no such occurrence."""
    end = code_block.rfind('\n' + ' ' * indent + '}')
    return end if end != -1 else len(code_block)


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Rebuild a full solution from the first fenced code block of a model answer.

    Reads ``example['task_id']``, ``example['prompt']`` and the model text from
    ``example['output']`` (falling back to ``example['gpt_completion']``).
    The first fenced block tagged with the language's ``full_name``
    (case-insensitive) is taken, and everything from the language's ``main``
    marker onward is dropped. The body is the slice that starts at the target
    signature (found via ``get_function_name``). It ends just before the last
    line consisting of ``indent`` spaces plus ``}``, or at the end of the block
    when no such line exists. For PHP/TS/JS that closing brace is re-appended.
    The prompt lines preceding the signature are prepended.

    Args:
        example: task record; mutated in place.
        lang_code: key into ``languge_settings`` (e.g. "python", "cpp", "ts").
        verbose: if true, print the extracted code block.

    Returns:
        ``example``, with ``example['generation']`` set. If extraction fails,
        a diagnostic is printed and ``generation`` falls back to the prompt,
        a newline, and the raw output (empty when there is no output).

    Raises:
        KeyError: if ``task_id``/``prompt`` is missing or ``lang_code`` is not
            a key of ``languge_settings``.
    """
    task_id = example['task_id']
    output = example.get('output', example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting['full_name']
    indent = setting['indent']

    try:
        # re.escape keeps the language tag literal inside the pattern.
        pattern = f'```{re.escape(lang.lower())}\n(.*?)```'
        code_block: str = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Remove the entry point (and anything after it).
        main_marker = setting.get('main')
        if main_marker and main_marker in code_block:
            code_block = code_block[:code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)

        try:
            start = code_block.lower().index(func_name.lower())
        except ValueError:
            # Signature not found: keep the whole block and use the
            # language's default closing-brace indentation.
            start = 0
        else:
            # Closing brace is expected at the signature's own indentation.
            indent = _count_spaces_before(code_block, start)
        end = _find_closing_line(code_block, indent)

        body = code_block[start:end]
        if lang_code.lower() in ['php', 'ts', 'js']:
            body += '\n' + ' ' * indent + '}'

        example['generation'] = func_prefix + '\n' + body + '\n'

    except Exception as ex:
        print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
            ex, task_id, output
        ))
        # `output` is None when neither key is present; don't crash the fallback.
        example['generation'] = example['prompt'] + '\n' + (output or '')

    return example
|
|
|
|
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """Clean up generated code by truncating it at stop words.

    Args:
        code: generated code to clean.
        language_type: language tag, compared case-insensitively; "python" and
            "ts" get special handling. None is treated like any other
            language.
        dataset: accepted for interface compatibility; not used here.
        issft: for Python only; when true, the code is first reduced to its
            ```python fenced block via _clean_python_code_for_sft.
        stop_words: extra truncation markers (default: none). For Python they
            are replaced by a fixed list; for "ts" they are extended with
            TypeScript-specific markers.

    Returns:
        The code cut at the earliest stop word found.
    """
    # None instead of a shared mutable [] default.
    stop_words = [] if stop_words is None else list(stop_words)
    lang = (language_type or "").lower()

    if lang == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # NOTE: this replaces, rather than extends, caller-supplied stop words.
        stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        code = _truncate_code_at_stopwords(code, stop_words)
    elif lang == "ts":
        ts_stop_words = ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"]
        code = _truncate_code_at_stopwords(code, stop_words + ts_stop_words)
    else:
        code = _truncate_code_at_stopwords(code, stop_words)

    return code
|
|
|
|
def _clean_python_code_for_sft(code):
|
|
code = code.replace("\r", "")
|
|
if "```python" in code:
|
|
code_start_idx = code.index("```python")
|
|
code = code[code_start_idx:].replace("```python", "").strip()
|
|
end_idx = code.find("```") if "```" in code else len(code)
|
|
code = code[:end_idx].strip()
|
|
|
|
return code
|
|
|
|
def _truncate_code_at_stopwords(code, stop_words):
|
|
min_stop_idx = len(code)
|
|
for stop_word in stop_words:
|
|
stop_index = code.find(stop_word)
|
|
if 0 <= stop_index < min_stop_idx:
|
|
min_stop_idx = stop_index
|
|
return code[:min_stop_idx] |