import re languge_settings = { 'python': { 'full_name': 'Python', 'indent': 4, }, 'cpp': { 'full_name': 'cpp', 'indent': 0, 'main': "int main()", }, 'java': { 'full_name': 'Java', 'indent': 4, 'main': "public static void main", }, 'cs': { 'full_name': "csharp", 'indent': 0, 'main': "public static void Main", }, 'php': { 'full_name': "PHP", 'indent': 0, }, 'ts': { 'full_name': "TypeScript", 'indent': 0, }, 'js': { 'full_name': "JavaScript", 'indent': 0 }, 'sh': { 'full_name': "Bash", 'indent': 0 } } def get_function_name(question: str, lang: str): func_lines = [x for x in question.strip().split('\n') if x.strip()] if lang.lower() == 'python': func_idx = [i for i in range(len(func_lines)) if func_lines[i].startswith("def ")][-1] func_name = func_lines[func_idx].split('(')[0].strip() func_prefix = "\n".join(func_lines[:func_idx]) return func_name, func_prefix func_name = func_lines[-1].split('{')[0].strip() func_prefix = "\n".join(func_lines[:-1]) return func_name, func_prefix def extract_generation_code(example: str, lang_code: str, verbose: bool=False): task_id = example['task_id'] output = example.get('output', example.get("gpt_completion")) question = example["prompt"].strip() setting = languge_settings[lang_code] lang = setting['full_name'] indent = setting['indent'] try: code_block: str = re.findall(f'```{lang.lower()}\n(.*?)```', output, re.DOTALL | re.IGNORECASE)[0] if verbose: print(">>> Task: {}\n{}".format(task_id, code_block)) # Remove main if setting.get('main', None) and setting['main'] in code_block: main_start = code_block.index(setting['main']) code_block = code_block[:main_start] func_name, func_prefix = get_function_name(question, lang) try: start = code_block.lower().index(func_name.lower()) indent = 0 while start - indent >= 0 and code_block[start - indent-1] == ' ': indent += 1 try: end = code_block.rindex('\n' + ' '*indent + '}') except: end = len(code_block) except: start = 0 try: end = code_block.rindex('\n' + ' '*indent + '}') except: end = len(code_block) body = code_block[start:end] if lang_code.lower() in ['php', 'ts', 'js']: body += '\n' + ' '*indent + '}' generation = func_prefix + '\n' + body + '\n' example['generation'] = generation except Exception as ex: print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format( ex, task_id, output )) example['generation'] = example['prompt'] + '\n' + output return example def cleanup_code( code: str, language_type: str = None, dataset: str = None, issft: bool = False, stop_words = [] ): """ Cleans up the generated code. """ if language_type.lower() == "python": if issft: code = _clean_python_code_for_sft(code) stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"] code = _truncate_code_at_stopwords(code, stop_words) elif language_type.lower() == "ts": code = _truncate_code_at_stopwords(code, stop_words + ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"]) else: code = _truncate_code_at_stopwords(code, stop_words) return code def _clean_python_code_for_sft(code): code = code.replace("\r", "") if "```python" in code: code_start_idx = code.index("```python") code = code[code_start_idx:].replace("```python", "").strip() end_idx = code.find("```") if "```" in code else len(code) code = code[:end_idx].strip() return code def _truncate_code_at_stopwords(code, stop_words): min_stop_idx = len(code) for stop_word in stop_words: stop_index = code.find(stop_word) if 0 <= stop_index < min_stop_idx: min_stop_idx = stop_index return code[:min_stop_idx]