Fix injected task import in Task.populate()

This commit is contained in:
allegroai 2024-07-26 19:13:34 +03:00
parent 4417812410
commit d826e9806e

View File

@ -422,8 +422,9 @@ class CreateAndPopulate(object):
"diff --git a{script_entry} b{script_entry}\n" \
"--- a{script_entry}\n" \
"+++ b{script_entry}\n" \
"@@ -{idx_a},0 +{idx_b},3 @@\n" \
"+from clearml import Task\n" \
"@@ -{idx_a},0 +{idx_b},4 @@\n" \
"+try: from allegroai import Task\n" \
"+except ImportError: from clearml import Task\n" \
"+(__name__ != \"__main__\") or Task.init()\n" \
"+\n".format(
script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1)
@ -432,7 +433,11 @@ class CreateAndPopulate(object):
pass
elif local_entry_file and lines:
# if we are here it means we do not have a git diff, but a single script file
init_lines = ["from clearml import Task\n", "(__name__ != \"__main__\") or Task.init()\n\n"]
init_lines = [
"try: from allegroai import Task\n",
"except ImportError: from clearml import Task\n",
'(__name__ != "__main__") or Task.init()\n\n',
]
task_state['script']['diff'] = ''.join(lines[:idx_a] + init_lines + lines[idx_a:])
# no need to add anything, we patched it.
task_init_patch = ""
@ -440,7 +445,8 @@ class CreateAndPopulate(object):
# Add Task.init call
# if we are here it means we do not have a git diff, but a single script file
task_init_patch += \
"from clearml import Task\n" \
"try: from allegroai import Task\n" \
"except ImportError: from clearml import Task\n" \
"(__name__ != \"__main__\") or Task.init()\n\n"
task_state['script']['diff'] = task_init_patch + task_state['script'].get('diff', '')
task_init_patch = ""