Fix "from clearml" runtime diff patching: make sure the patched import is moved to after all the __future__ imports, including handling of triple quotes in comments

This commit is contained in:
allegroai 2021-02-11 14:46:06 +02:00
parent 296f7970df
commit 784c676f5b
2 changed files with 80 additions and 3 deletions

View File

@ -92,7 +92,7 @@ from clearml_agent.helper.process import (
commit_docker, terminate_process,
)
from clearml_agent.helper.package.priority_req import PriorityPackageRequirement, PackageCollectorRequirement
from clearml_agent.helper.repo import clone_repository_cached, RepoInfo, VCS
from clearml_agent.helper.repo import clone_repository_cached, RepoInfo, VCS, fix_package_import_diff_patch
from clearml_agent.helper.resource_monitor import ResourceMonitor
from clearml_agent.helper.runtime_verification import check_runtime, print_uptime_properties
from clearml_agent.session import Session
@ -1688,6 +1688,7 @@ class Worker(ServiceCommandSection):
repo_info = None
directory = None
vcs = None
script_file = None
if has_repository:
vcs, repo_info = self._get_repo_info(execution, task, venv_folder)
directory = Path(repo_info.root, execution.working_dir or ".")
@ -1698,17 +1699,26 @@ class Worker(ServiceCommandSection):
self.apply_diff(
task=task, vcs=vcs, execution_info=execution, repo_info=repo_info
)
script_file = Path(directory) / execution.entry_point
if is_literal_script:
self.log.info("found literal script in `script.diff`")
directory, script = literal_script.create_notebook_file(
task, execution, repo_info
)
execution.entry_point = script
if not has_repository:
return directory, None, None
script_file = Path(execution.entry_point)
else:
# in case of no literal script, there is no difference between an empty working dir and `.`
execution.working_dir = execution.working_dir or "."
# fix our import patch (in case we have __future__)
if script_file and script_file.is_file():
fix_package_import_diff_patch(script_file.as_posix())
if is_literal_script and not has_repository:
return directory, None, None
if not directory:
assert False, "unreachable code"
return directory, vcs, repo_info

View File

@ -638,3 +638,70 @@ def clone_repository_cached(session, execution, destination):
repo_info = attr.evolve(repo_info, url=no_password_url)
return vcs, repo_info
def _triple_quoted_rows(rows):
    """Return the positions in ``rows`` covered by triple-quoted blocks.

    :param rows: list of ``(original_index, text)`` pairs.
    :return: set of positions (indices into ``rows``, not original indices)
        spanning every block from its opening delimiter to the matching
        closing one. A block left open at the end of ``rows`` is not included.

    Both triple-double-quote and triple-single-quote delimiters are recognised;
    while one kind is open, the other kind is treated as plain text.
    """
    covered = set()
    open_delim = None
    open_pos = -1
    for pos, (_, text) in enumerate(rows):
        cursor = 0
        while True:
            if open_delim is None:
                # look for whichever opening delimiter appears first
                hits = [(text.find(delim, cursor), delim) for delim in ('"""', "'''")]
                hits = [hit for hit in hits if hit[0] >= 0]
                if not hits:
                    break
                at, open_delim = min(hits)
                open_pos = pos
            else:
                at = text.find(open_delim, cursor)
                if at < 0:
                    break
                covered.update(range(open_pos, pos + 1))
                open_delim = None
            cursor = at + 3
    return covered


def fix_package_import_diff_patch(entry_script_file):
    """Move the two patched clearml lines below the script's ``__future__`` imports.

    The function only acts when the first line of the file starts with
    ``from clearml `` and the second line contains ``Task.init``. In that case
    it looks for the leading run of ``__future__`` imports (ignoring blank
    lines, ``#`` comment lines and triple-quoted blocks) and rewrites the file
    so that the first two lines are placed right after the last of them.

    The operation is best-effort: the file is left untouched when it cannot be
    read, when it does not start with the patch, when there is no
    ``__future__`` import, or when a ``__future__`` import continuation runs
    past the end of the file. Write errors are ignored.

    :param entry_script_file: path of the script file to fix in place.
    :return: None
    """
    # noinspection PyBroadException
    try:
        with open(entry_script_file, 'rt') as f:
            original_lines = f.readlines()
    except Exception:
        return

    # make sure we are the first import (i.e. we patched the source code);
    # a file shorter than two lines cannot contain the two-line patch
    if (
        len(original_lines) < 2
        or not original_lines[0].strip().startswith('from clearml ')
        or 'Task.init' not in original_lines[1]
    ):
        return

    # skip over the first two lines, they are ours,
    # then skip over empty or comment lines (keeping the original line index)
    lines = [
        (i, line.split('#', 1)[0].rstrip())
        for i, line in enumerate(original_lines)
        if i >= 2 and line.strip('\r\n\t ') and not line.strip().startswith('#')
    ]

    # drop everything that sits inside triple-quoted blocks
    skip_lines = _triple_quoted_rows(lines)
    lines = [pair for i, pair in enumerate(lines) if i not in skip_lines]

    from_future = re.compile(r"^from[\s]*__future__[\s]*")
    import_future = re.compile(r"^import[\s]*__future__[\s]*")

    # find the last line of the leading block of __future__ imports
    found_index = -1
    for a_i, (_, a_line) in enumerate(lines):
        if found_index >= a_i:
            # already consumed as a continuation of a previous import
            continue
        if not (from_future.match(a_line) or import_future.match(a_line)):
            # first real statement that is not a __future__ import, stop here
            break

        found_index = a_i
        line = a_line
        # either we have a `\` at the end of the line or an open parenthesis
        parenthesized_lines = '(' in line and ')' not in line
        while line.endswith('\\') or parenthesized_lines:
            found_index += 1
            if found_index >= len(lines):
                # the continuation never terminates, the source is malformed:
                # do not touch the file
                return
            line = lines[found_index][1]
            if ')' in line:
                break

    # no imports found
    if found_index < 0:
        return

    # now we need to move the two patched lines after the __future__ imports
    entry_line = lines[found_index][0]
    head = original_lines[2:entry_line + 1]
    if head and not head[-1].endswith('\n'):
        # the import was the last line of a file with no trailing newline
        head[-1] += '\n'
    new_lines = head + original_lines[0:2] + original_lines[entry_line + 1:]

    # noinspection PyBroadException
    try:
        with open(entry_script_file, 'wt') as f:
            f.writelines(new_lines)
    except Exception:
        return