Black formatting

allegroai 2024-08-18 10:36:42 +03:00
parent cf1178ff5f
commit a9819416fd


@@ -3,9 +3,9 @@ import json
import os
import re
import tempfile
from sys import platform
from functools import reduce
from logging import getLogger
from sys import platform
from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any
from pathlib2 import Path
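
The hunks below are the output of running Black over this file; judging by the resulting line widths, a line length of roughly 120 characters appears to have been used, although the exact configuration is not part of this commit. A minimal sketch of reproducing the same kind of normalization (double quotes, comma spacing, rewrapped conditions) with Black's Python API, assuming the black package is installed and that line_length value:

# Illustrative sketch, not part of the diff; the line_length value is an assumption.
import black

source = (
    "x = {'a': 1}\n"
    "if (not a and\n"
    "        b):\n"
    "    do_something(1,2)\n"
)
print(black.format_str(source, mode=black.Mode(line_length=120)))
# single quotes become double quotes, arguments get spaces after commas,
# and the wrapped condition is joined back onto one line
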
@@ -27,28 +27,28 @@ class CreateAndPopulate(object):
)
def __init__(
self,
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str]
repo=None, # type: Optional[str]
branch=None, # type: Optional[str]
commit=None, # type: Optional[str]
script=None, # type: Optional[str]
working_directory=None, # type: Optional[str]
module=None, # type: Optional[str]
packages=None, # type: Optional[Union[bool, Sequence[str]]]
requirements_file=None, # type: Optional[Union[str, Path]]
docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str]
base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool
force_single_script_file=False, # type: bool
raise_on_missing_entries=False, # type: bool
verbose=False, # type: bool
binary=None # type: Optional[str]
self,
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str]
repo=None, # type: Optional[str]
branch=None, # type: Optional[str]
commit=None, # type: Optional[str]
script=None, # type: Optional[str]
working_directory=None, # type: Optional[str]
module=None, # type: Optional[str]
packages=None, # type: Optional[Union[bool, Sequence[str]]]
requirements_file=None, # type: Optional[Union[str, Path]]
docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str]
base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool
force_single_script_file=False, # type: bool
raise_on_missing_entries=False, # type: bool
verbose=False, # type: bool
binary=None, # type: Optional[str]
):
# type: (...) -> None
"""
@@ -106,15 +106,16 @@ class CreateAndPopulate(object):
if not script and not module:
raise ValueError("Entry point script not provided")
if not repo and not folder and (script and not Path(script).is_file()):
raise ValueError("Script file \'{}\' could not be found".format(script))
raise ValueError("Script file '{}' could not be found".format(script))
if raise_on_missing_entries and commit and branch:
raise ValueError(
"Specify either a branch/tag or specific commit id, not both (either --commit or --branch)")
if raise_on_missing_entries and not folder and working_directory and working_directory.startswith('/'):
raise ValueError("working directory \'{}\', must be relative to repository root")
"Specify either a branch/tag or specific commit id, not both (either --commit or --branch)"
)
if raise_on_missing_entries and not folder and working_directory and working_directory.startswith("/"):
raise ValueError("working directory '{}', must be relative to repository root")
if requirements_file and not Path(requirements_file).is_file():
raise ValueError("requirements file could not be found \'{}\'")
raise ValueError("requirements file could not be found '{}'")
self.folder = folder
self.commit = commit
@@ -124,8 +125,9 @@ class CreateAndPopulate(object):
self.module = module
self.cwd = working_directory
assert not packages or isinstance(packages, (tuple, list, bool))
self.packages = list(packages) if packages is not None and not isinstance(packages, bool) \
else (packages or None)
self.packages = (
list(packages) if packages is not None and not isinstance(packages, bool) else (packages or None)
)
self.requirements_file = Path(requirements_file) if requirements_file else None
self.base_task_id = base_task_id
self.docker = dict(image=docker, args=docker_args, bash_script=docker_bash_setup_script)
@@ -179,14 +181,20 @@ class CreateAndPopulate(object):
stand_alone_script_outside_repo = True
if not os.path.isfile(entry_point) and not stand_alone_script_outside_repo:
if (not Path(self.script).is_absolute() and not Path(self.cwd).is_absolute() and
(Path(self.folder) / self.cwd / self.script).is_file()):
if (
not Path(self.script).is_absolute()
and not Path(self.cwd).is_absolute()
and (Path(self.folder) / self.cwd / self.script).is_file()
):
entry_point = (Path(self.folder) / self.cwd / self.script).as_posix()
elif (Path(self.cwd).is_absolute() and not Path(self.script).is_absolute() and
(Path(self.cwd) / self.script).is_file()):
elif (
Path(self.cwd).is_absolute()
and not Path(self.script).is_absolute()
and (Path(self.cwd) / self.script).is_file()
):
entry_point = (Path(self.cwd) / self.script).as_posix()
else:
raise ValueError("Script entrypoint file \'{}\' could not be found".format(entry_point))
raise ValueError("Script entrypoint file '{}' could not be found".format(entry_point))
local_entry_file = entry_point
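
The conditions above resolve the entry point by trying the script path against the repository folder and the working directory. A rough sketch of the same lookup order (the original uses pathlib2; this sketch uses the standard pathlib, and all paths are hypothetical):

# Illustrative sketch, not part of the diff.
from pathlib import Path

def resolve_entry_point(folder, cwd, script):
    if Path(script).is_file():
        return Path(script).as_posix()
    if not Path(script).is_absolute() and not Path(cwd).is_absolute() and (Path(folder) / cwd / script).is_file():
        # relative cwd and script: resolve them under the checked-out repository folder
        return (Path(folder) / cwd / script).as_posix()
    if Path(cwd).is_absolute() and not Path(script).is_absolute() and (Path(cwd) / script).is_file():
        # absolute cwd: resolve the relative script under it
        return (Path(cwd) / script).as_posix()
    raise ValueError("Script entrypoint file '{}' could not be found".format(script))
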
@@ -215,20 +223,25 @@ class CreateAndPopulate(object):
detailed_req_report=False,
force_single_script=True,
)
if repo_info.script['diff']:
print("Warning: local git repo diff is ignored, "
"storing only the standalone script form {}".format(self.script))
repo_info.script['diff'] = a_repo_info.script['diff'] or ''
repo_info.script['entry_point'] = a_repo_info.script['entry_point']
if repo_info.script["diff"]:
print(
"Warning: local git repo diff is ignored, "
"storing only the standalone script form {}".format(self.script)
)
repo_info.script["diff"] = a_repo_info.script["diff"] or ""
repo_info.script["entry_point"] = a_repo_info.script["entry_point"]
if a_create_requirements:
repo_info['requirements'] = a_repo_info.script.get('requirements') or {}
repo_info["requirements"] = a_repo_info.script.get("requirements") or {}
# if we have no repository and no requirements, raise an error
if self.raise_on_missing_entries and (not self.requirements_file and not self.packages) \
and not self.repo and (
not repo_info or not repo_info.script or not repo_info.script.get('repository')) \
and (not entry_point or not entry_point.endswith(".sh")):
raise ValueError("Standalone script detected \'{}\', but no requirements provided".format(self.script))
if (
self.raise_on_missing_entries
and (not self.requirements_file and not self.packages)
and not self.repo
and (not repo_info or not repo_info.script or not repo_info.script.get("repository"))
and (not entry_point or not entry_point.endswith(".sh"))
):
raise ValueError("Standalone script detected '{}', but no requirements provided".format(self.script))
if dry_run:
task = None
task_state = dict(
@@ -237,106 +250,127 @@ class CreateAndPopulate(object):
type=str(self.task_type or Task.TaskTypes.training),
) # type: dict
if self.output_uri is not None:
task_state['output'] = dict(destination=self.output_uri)
task_state["output"] = dict(destination=self.output_uri)
else:
task_state = dict(script={})
if self.base_task_id:
if self.verbose:
print('Cloning task {}'.format(self.base_task_id))
print("Cloning task {}".format(self.base_task_id))
task = Task.clone(source_task=self.base_task_id, project=Task.get_project_id(self.project_name))
self._set_output_uri(task)
else:
# noinspection PyProtectedMember
task = Task._create(
task_name=self.task_name, project_name=self.project_name,
task_type=self.task_type or Task.TaskTypes.training)
task_name=self.task_name,
project_name=self.project_name,
task_type=self.task_type or Task.TaskTypes.training,
)
self._set_output_uri(task)
# if there is nothing to populate, return
if not any([
self.folder, self.commit, self.branch, self.repo, self.script, self.module, self.cwd,
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
):
if not any(
[
self.folder,
self.commit,
self.branch,
self.repo,
self.script,
self.module,
self.cwd,
self.packages,
self.requirements_file,
self.base_task_id,
]
+ (list(self.docker.values()))
):
return task
# clear the script section
task_state['script'] = {}
task_state["script"] = {}
if repo_info:
task_state['script']['repository'] = repo_info.script['repository']
task_state['script']['version_num'] = repo_info.script['version_num']
task_state['script']['branch'] = repo_info.script['branch']
task_state['script']['diff'] = repo_info.script['diff'] or ''
task_state['script']['working_dir'] = repo_info.script['working_dir']
task_state['script']['entry_point'] = repo_info.script['entry_point']
task_state['script']['binary'] = self.binary or ('/bin/bash' if (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \
else repo_info.script['binary'])
task_state['script']['requirements'] = repo_info.script.get('requirements') or {}
task_state["script"]["repository"] = repo_info.script["repository"]
task_state["script"]["version_num"] = repo_info.script["version_num"]
task_state["script"]["branch"] = repo_info.script["branch"]
task_state["script"]["diff"] = repo_info.script["diff"] or ""
task_state["script"]["working_dir"] = repo_info.script["working_dir"]
task_state["script"]["entry_point"] = repo_info.script["entry_point"]
task_state["script"]["binary"] = self.binary or (
"/bin/bash"
if (
(repo_info.script["entry_point"] or "").lower().strip().endswith(".sh")
and not (repo_info.script["entry_point"] or "").lower().strip().startswith("-m ")
)
else repo_info.script["binary"]
)
task_state["script"]["requirements"] = repo_info.script.get("requirements") or {}
if self.cwd:
cwd = self.cwd
if not Path(cwd).is_absolute():
# cwd should be relative to the repo_root, but we need the full path
# (repo_root + cwd) in order to resolve the entry point
cwd = os.path.normpath((Path(repo_info.script['repo_root']) / self.cwd).as_posix())
cwd = os.path.normpath((Path(repo_info.script["repo_root"]) / self.cwd).as_posix())
if not Path(cwd).is_dir():
# we need to leave it as is, we have no idea, and this is a repo
cwd = self.cwd
elif not Path(cwd).is_dir():
# we were passed an absolute dir and it does not exist
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
raise ValueError("Working directory '{}' could not be found".format(cwd))
if self.module:
entry_point = "-m {}".format(self.module)
elif stand_alone_script_outside_repo:
# this should be relative and the temp file we generated
entry_point = repo_info.script['entry_point']
entry_point = repo_info.script["entry_point"]
else:
entry_point = os.path.normpath(
Path(repo_info.script['repo_root']) /
repo_info.script['working_dir'] / repo_info.script['entry_point']
Path(repo_info.script["repo_root"])
/ repo_info.script["working_dir"]
/ repo_info.script["entry_point"]
)
# resolve entry_point relative to the current working directory
if Path(cwd).is_absolute():
entry_point = Path(entry_point).relative_to(cwd).as_posix()
else:
entry_point = repo_info.script['entry_point']
entry_point = repo_info.script["entry_point"]
# restore cwd - make it relative to the repo_root again
if Path(cwd).is_absolute():
# now cwd is relative again
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
cwd = Path(cwd).relative_to(repo_info.script["repo_root"]).as_posix()
# make sure we always have / (never \\)
if platform == "win32":
entry_point = entry_point.replace('\\', '/') if entry_point else ""
cwd = cwd.replace('\\', '/') if cwd else ""
entry_point = entry_point.replace("\\", "/") if entry_point else ""
cwd = cwd.replace("\\", "/") if cwd else ""
task_state['script']['entry_point'] = entry_point or ""
task_state['script']['working_dir'] = cwd or "."
task_state["script"]["entry_point"] = entry_point or ""
task_state["script"]["working_dir"] = cwd or "."
elif self.repo:
cwd = '/'.join([p for p in (self.cwd or '.').split('/') if p and p != '.'])
cwd = "/".join([p for p in (self.cwd or ".").split("/") if p and p != "."])
# normalize backslashes and remove first one
if self.module:
entry_point = "-m {}".format(self.module)
else:
entry_point = '/'.join([p for p in self.script.split('/') if p and p != '.'])
if cwd and entry_point.startswith(cwd + '/'):
entry_point = entry_point[len(cwd) + 1:]
entry_point = "/".join([p for p in self.script.split("/") if p and p != "."])
if cwd and entry_point.startswith(cwd + "/"):
entry_point = entry_point[len(cwd) + 1 :]
task_state['script']['repository'] = self.repo
task_state['script']['version_num'] = self.commit or None
task_state['script']['branch'] = self.branch or None
task_state['script']['diff'] = ''
task_state['script']['working_dir'] = cwd or '.'
task_state['script']['entry_point'] = entry_point or ""
task_state["script"]["repository"] = self.repo
task_state["script"]["version_num"] = self.commit or None
task_state["script"]["branch"] = self.branch or None
task_state["script"]["diff"] = ""
task_state["script"]["working_dir"] = cwd or "."
task_state["script"]["entry_point"] = entry_point or ""
if self.script and Path(self.script).is_file() and (
self.force_single_script_file or Path(self.script).is_absolute()):
if (
self.script
and Path(self.script).is_file()
and (self.force_single_script_file or Path(self.script).is_absolute())
):
self.force_single_script_file = True
create_requirements = self.packages is True
repo_info, requirements = ScriptInfo.get(
@@ -349,29 +383,37 @@ class CreateAndPopulate(object):
detailed_req_report=False,
force_single_script=True,
)
task_state['script']['binary'] = self.binary or ('/bin/bash' if (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \
else repo_info.script['binary'])
task_state['script']['diff'] = repo_info.script['diff'] or ''
task_state['script']['entry_point'] = repo_info.script['entry_point']
task_state["script"]["binary"] = self.binary or (
"/bin/bash"
if (
(repo_info.script["entry_point"] or "").lower().strip().endswith(".sh")
and not (repo_info.script["entry_point"] or "").lower().strip().startswith("-m ")
)
else repo_info.script["binary"]
)
task_state["script"]["diff"] = repo_info.script["diff"] or ""
task_state["script"]["entry_point"] = repo_info.script["entry_point"]
if create_requirements:
task_state['script']['requirements'] = repo_info.script.get('requirements') or {}
task_state["script"]["requirements"] = repo_info.script.get("requirements") or {}
else:
if self.binary:
task_state["script"]["binary"] = self.binary
elif entry_point and entry_point.lower().strip().endswith(".sh") and not \
entry_point.lower().strip().startswith("-m"):
elif (
entry_point
and entry_point.lower().strip().endswith(".sh")
and not entry_point.lower().strip().startswith("-m")
):
task_state["script"]["binary"] = "/bin/bash"
else:
# standalone task
task_state['script']['entry_point'] = self.script if self.script else \
("-m {}".format(self.module) if self.module else "")
task_state['script']['working_dir'] = '.'
task_state["script"]["entry_point"] = (
self.script if self.script else ("-m {}".format(self.module) if self.module else "")
)
task_state["script"]["working_dir"] = "."
# update requirements
reqs = []
if self.requirements_file:
with open(self.requirements_file.as_posix(), 'rt') as f:
with open(self.requirements_file.as_posix(), "rt") as f:
reqs = [line.strip() for line in f.readlines()]
if self.packages and self.packages is not True:
reqs += self.packages
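
The explicit requirements above are the union of an optional requirements file and the packages argument; the next hunk then makes sure clearml itself appears in the list. A small sketch of that merge (the file name is a placeholder):

# Illustrative sketch, not part of the diff.
def collect_requirements(requirements_file=None, packages=None):
    reqs = []
    if requirements_file:
        with open(requirements_file, "rt") as f:
            reqs = [line.strip() for line in f.readlines()]
    if packages and packages is not True:  # packages=True asks for auto-detection, not an explicit list
        reqs += list(packages)
    return reqs

# e.g. collect_requirements("requirements.txt", ["numpy", "pandas"])
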
@@ -379,66 +421,76 @@ class CreateAndPopulate(object):
# make sure we have clearml.
clearml_found = False
for line in reqs:
if line.strip().startswith('#'):
if line.strip().startswith("#"):
continue
package = reduce(lambda a, b: a.split(b)[0], "#;@=~<>[", line).strip()
if package == 'clearml':
if package == "clearml":
clearml_found = True
break
if not clearml_found:
reqs.append('clearml')
task_state['script']['requirements'] = {'pip': '\n'.join(reqs)}
elif not self.repo and repo_info and not repo_info.script.get('requirements'):
reqs.append("clearml")
task_state["script"]["requirements"] = {"pip": "\n".join(reqs)}
elif not self.repo and repo_info and not repo_info.script.get("requirements"):
# we are in local mode, make sure we have a "requirements.txt", it is required
reqs_txt_file = Path(repo_info.script['repo_root']) / "requirements.txt"
poetry_toml_file = Path(repo_info.script['repo_root']) / "pyproject.toml"
reqs_txt_file = Path(repo_info.script["repo_root"]) / "requirements.txt"
poetry_toml_file = Path(repo_info.script["repo_root"]) / "pyproject.toml"
if self.raise_on_missing_entries and not reqs_txt_file.is_file() and not poetry_toml_file.is_file():
raise ValueError(
"requirements.txt not found [{}] "
"Use --requirements or --packages".format(reqs_txt_file.as_posix()))
"Use --requirements or --packages".format(reqs_txt_file.as_posix())
)
if self.add_task_init_call:
script_entry = ('/' + task_state['script'].get('working_dir', '.')
+ '/' + task_state['script']['entry_point'])
script_entry = (
"/" + task_state["script"].get("working_dir", ".") + "/" + task_state["script"]["entry_point"]
)
if platform == "win32":
script_entry = os.path.normpath(script_entry).replace('\\', '/')
script_entry = os.path.normpath(script_entry).replace("\\", "/")
else:
script_entry = os.path.abspath(script_entry)
idx_a = 0
lines = None
# find the right entry for the patch if we have a local file (basically right after any __future__ imports)
if (local_entry_file and not stand_alone_script_outside_repo and not self.module and
str(local_entry_file).lower().endswith(".py")):
with open(local_entry_file, 'rt') as f:
if (
local_entry_file
and not stand_alone_script_outside_repo
and not self.module
and str(local_entry_file).lower().endswith(".py")
):
with open(local_entry_file, "rt") as f:
lines = f.readlines()
future_found = self._locate_future_import(lines)
if future_found >= 0:
idx_a = future_found + 1
task_init_patch = ''
if ((self.repo or task_state.get('script', {}).get('repository')) and
not self.force_single_script_file and not stand_alone_script_outside_repo):
task_init_patch = ""
if (
(self.repo or task_state.get("script", {}).get("repository"))
and not self.force_single_script_file
and not stand_alone_script_outside_repo
):
# if we do not have requirements, add clearml to the requirements.txt
if not reqs:
task_init_patch += \
"diff --git a/requirements.txt b/requirements.txt\n" \
"--- a/requirements.txt\n" \
"+++ b/requirements.txt\n" \
"@@ -0,0 +1,1 @@\n" \
task_init_patch += (
"diff --git a/requirements.txt b/requirements.txt\n"
"--- a/requirements.txt\n"
"+++ b/requirements.txt\n"
"@@ -0,0 +1,1 @@\n"
"+clearml\n"
)
# Add Task.init call
if not self.module and script_entry and str(script_entry).lower().endswith(".py"):
task_init_patch += \
"diff --git a{script_entry} b{script_entry}\n" \
"--- a{script_entry}\n" \
"+++ b{script_entry}\n" \
"@@ -{idx_a},0 +{idx_b},4 @@\n" \
"+try: from allegroai import Task\n" \
"+except ImportError: from clearml import Task\n" \
"+(__name__ != \"__main__\") or Task.init()\n" \
"+\n".format(
script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1)
task_init_patch += (
"diff --git a{script_entry} b{script_entry}\n"
"--- a{script_entry}\n"
"+++ b{script_entry}\n"
"@@ -{idx_a},0 +{idx_b},4 @@\n"
"+try: from allegroai import Task\n"
"+except ImportError: from clearml import Task\n"
'+(__name__ != "__main__") or Task.init()\n'
"+\n".format(script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1)
)
elif self.module:
# if we are here, do nothing
pass
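
For reference, the patch template above renders into an ordinary unified diff that prepends the Task.init shim to the entry script. A sketch of the rendered text for a hypothetical script_entry of "/./train.py" with no __future__ imports (idx_a = 0):

# Illustrative sketch, not part of the diff; the entry point and indices are made up.
script_entry, idx_a = "/./train.py", 0
patch = (
    "diff --git a{script_entry} b{script_entry}\n"
    "--- a{script_entry}\n"
    "+++ b{script_entry}\n"
    "@@ -{idx_a},0 +{idx_b},4 @@\n"
    "+try: from allegroai import Task\n"
    "+except ImportError: from clearml import Task\n"
    '+(__name__ != "__main__") or Task.init()\n'
    "+\n".format(script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1)
)
print(patch)  # a four-line insertion at the top of a/./train.py
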
@@ -449,57 +501,62 @@ class CreateAndPopulate(object):
"except ImportError: from clearml import Task\n",
'(__name__ != "__main__") or Task.init()\n\n',
]
task_state['script']['diff'] = ''.join(lines[:idx_a] + init_lines + lines[idx_a:])
task_state["script"]["diff"] = "".join(lines[:idx_a] + init_lines + lines[idx_a:])
# no need to add anything, we patched it.
task_init_patch = ""
elif str(script_entry or "").lower().endswith(".py"):
# Add Task.init call
# if we are here it means we do not have a git diff, but a single script file
task_init_patch += \
"try: from allegroai import Task\n" \
"except ImportError: from clearml import Task\n" \
"(__name__ != \"__main__\") or Task.init()\n\n"
task_state['script']['diff'] = task_init_patch + task_state['script'].get('diff', '')
task_init_patch += (
"try: from allegroai import Task\n"
"except ImportError: from clearml import Task\n"
'(__name__ != "__main__") or Task.init()\n\n'
)
task_state["script"]["diff"] = task_init_patch + task_state["script"].get("diff", "")
task_init_patch = ""
# make sure we add the diff at the end of the current diff
task_state['script']['diff'] = task_state['script'].get('diff', '')
if task_state['script']['diff'] and not task_state['script']['diff'].endswith('\n'):
task_state['script']['diff'] += '\n'
task_state['script']['diff'] += task_init_patch
task_state["script"]["diff"] = task_state["script"].get("diff", "")
if task_state["script"]["diff"] and not task_state["script"]["diff"].endswith("\n"):
task_state["script"]["diff"] += "\n"
task_state["script"]["diff"] += task_init_patch
# set base docker image if provided
if self.docker:
if dry_run:
task_state['container'] = dict(
image=self.docker.get('image') or '',
arguments=self.docker.get('args') or '',
setup_shell_script=self.docker.get('bash_script') or '',
task_state["container"] = dict(
image=self.docker.get("image") or "",
arguments=self.docker.get("args") or "",
setup_shell_script=self.docker.get("bash_script") or "",
)
else:
task.set_base_docker(
docker_image=self.docker.get('image'),
docker_arguments=self.docker.get('args'),
docker_setup_bash_script=self.docker.get('bash_script'),
docker_image=self.docker.get("image"),
docker_arguments=self.docker.get("args"),
docker_setup_bash_script=self.docker.get("bash_script"),
)
if self.verbose:
if task_state['script']['repository']:
repo_details = {k: v for k, v in task_state['script'].items()
if v and k not in ('diff', 'requirements', 'binary')}
print('Repository Detected\n{}'.format(json.dumps(repo_details, indent=2)))
if task_state["script"]["repository"]:
repo_details = {
k: v for k, v in task_state["script"].items() if v and k not in ("diff", "requirements", "binary")
}
print("Repository Detected\n{}".format(json.dumps(repo_details, indent=2)))
else:
print('Standalone script detected\n Script: {}'.format(self.script))
print("Standalone script detected\n Script: {}".format(self.script))
if task_state['script'].get('requirements') and \
task_state['script']['requirements'].get('pip'):
print('Requirements:{}{}'.format(
'\n Using requirements.txt: {}'.format(
self.requirements_file.as_posix()) if self.requirements_file else '',
'\n {}Packages: {}'.format('Additional ' if self.requirements_file else '', self.packages)
if self.packages else ''
))
if task_state["script"].get("requirements") and task_state["script"]["requirements"].get("pip"):
print(
"Requirements:{}{}".format(
"\n Using requirements.txt: {}".format(self.requirements_file.as_posix())
if self.requirements_file
else "",
"\n {}Packages: {}".format("Additional " if self.requirements_file else "", self.packages)
if self.packages
else "",
)
)
if self.docker:
print('Base docker image: {}'.format(self.docker))
print("Base docker image: {}".format(self.docker))
if dry_run:
return task_state
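
Taken together, the class is typically used in two steps: construct it with the repository or script details, then call create_task, optionally with dry_run=True so that the populated task state dictionary is returned instead of a Task (the "if dry_run: return task_state" above). A hedged usage sketch; the import path, project, repo and script values are assumptions:

# Illustrative sketch, not part of the diff.
from clearml.backend_interface.task.populate import CreateAndPopulate  # assumed module path

populate = CreateAndPopulate(
    project_name="examples",
    task_name="remote run",
    repo="https://github.com/some-org/some-repo.git",
    branch="main",
    script="train.py",
    packages=["numpy"],
    docker="python:3.10",
)
task_definition = populate.create_task(dry_run=True)  # returns the task state dict instead of creating a Task
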
@@ -538,18 +595,17 @@ class CreateAndPopulate(object):
args_list.append(a)
continue
try:
parts = a.split('=', 1)
parts = a.split("=", 1)
assert len(parts) == 2
args_list.append(parts)
except Exception:
raise ValueError(
"Failed parsing argument \'{}\', arguments must be in \'<key>=<value>\' format")
raise ValueError("Failed parsing argument '{}', arguments must be in '<key>=<value>' format")
if not self.task:
return
task_params = self.task.get_parameters()
args_list = {'Args/{}'.format(k): v for k, v in args_list}
args_list = {"Args/{}".format(k): v for k, v in args_list}
task_params.update(args_list)
self.task.set_parameters(task_params)
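
The argument parsing above turns "key=value" strings into task hyperparameters under the Args section. A small sketch of that transformation with made-up values:

# Illustrative sketch, not part of the diff.
raw_args = ["batch_size=64", "lr=0.001"]
parsed = [a.split("=", 1) for a in raw_args]
task_params = {"Args/{}".format(k): v for k, v in parsed}
print(task_params)  # {'Args/batch_size': '64', 'Args/lr': '0.001'}
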
@@ -569,8 +625,11 @@ class CreateAndPopulate(object):
"""
# skip over the first two lines, they are ours
# then skip over empty or comment lines
lines = [(i, line.split('#', 1)[0].rstrip()) for i, line in enumerate(lines)
if line.strip('\r\n\t ') and not line.strip().startswith('#')]
lines = [
(i, line.split("#", 1)[0].rstrip())
for i, line in enumerate(lines)
if line.strip("\r\n\t ") and not line.strip().startswith("#")
]
# remove triple quotes ' """ '
nested_c = -1
@@ -597,11 +656,11 @@ class CreateAndPopulate(object):
# check the last import block
i, line = lines[found_index]
# whether we have a \\ character at the end of the line or the line is indented
parenthesized_lines = '(' in line and ')' not in line
while line.endswith('\\') or parenthesized_lines:
parenthesized_lines = "(" in line and ")" not in line
while line.endswith("\\") or parenthesized_lines:
found_index += 1
i, line = lines[found_index]
if ')' in line:
if ")" in line:
break
else:
@@ -671,33 +730,33 @@ if __name__ == '__main__':
@classmethod
def create_task_from_function(
cls,
a_function, # type: Callable
function_kwargs=None, # type: Optional[Dict[str, Any]]
function_input_artifacts=None, # type: Optional[Dict[str, str]]
function_return=None, # type: Optional[List[str]]
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str]
auto_connect_frameworks=None, # type: Optional[dict]
auto_connect_arg_parser=None, # type: Optional[dict]
repo=None, # type: Optional[str]
branch=None, # type: Optional[str]
commit=None, # type: Optional[str]
packages=None, # type: Optional[Union[str, Sequence[str]]]
docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str]
helper_functions=None, # type: Optional[Sequence[Callable]]
dry_run=False, # type: bool
task_template_header=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
skip_global_imports=False, # type: bool
working_dir=None # type: Optional[str]
cls,
a_function, # type: Callable
function_kwargs=None, # type: Optional[Dict[str, Any]]
function_input_artifacts=None, # type: Optional[Dict[str, str]]
function_return=None, # type: Optional[List[str]]
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str]
auto_connect_frameworks=None, # type: Optional[dict]
auto_connect_arg_parser=None, # type: Optional[dict]
repo=None, # type: Optional[str]
branch=None, # type: Optional[str]
commit=None, # type: Optional[str]
packages=None, # type: Optional[Union[str, Sequence[str]]]
docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str]
helper_functions=None, # type: Optional[Sequence[Callable]]
dry_run=False, # type: bool
task_template_header=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
skip_global_imports=False, # type: bool
working_dir=None, # type: Optional[str]
):
# type: (...) -> Optional[Dict, Task]
"""
@@ -793,8 +852,8 @@ if __name__ == '__main__':
if auto_connect_arg_parser is None:
auto_connect_arg_parser = True
assert (not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict)))
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
assert not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict))
assert not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict))
function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
@@ -819,9 +878,9 @@ if __name__ == '__main__':
function_input_artifacts = function_input_artifacts or dict()
# verify artifact kwargs:
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
if not all(len(v.split(".", 1)) == 2 for v in function_input_artifacts.values()):
raise ValueError(
'function_input_artifacts={}, it must in the format: '
"function_input_artifacts={}, it must in the format: "
'{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts)
)
inspect_args = None
@@ -835,16 +894,20 @@ if __name__ == '__main__':
# adjust the defaults so they match the args (match from the end)
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):]
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals) :]
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
getLogger().warning(
'Ignoring default argument values: '
'could not find all default valued for: \'{}\''.format(function_name))
"Ignoring default argument values: "
"could not find all default valued for: '{}'".format(function_name)
)
inspect_defaults_vals = []
function_kwargs = {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)} \
if inspect_defaults_vals else {str(k): None for k in inspect_defaults_args}
function_kwargs = (
{str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)}
if inspect_defaults_vals
else {str(k): None for k in inspect_defaults_args}
)
if function_kwargs:
if not inspect_args:
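
The block above pairs default values with parameter names from the end of the argument list, which is how Python itself associates them. A short illustration with the standard inspect module (the function is made up):

# Illustrative sketch, not part of the diff.
import inspect

def example(a, b, c=3, d=4):  # only the trailing parameters carry defaults
    pass

spec = inspect.getfullargspec(example)
defaults_args = spec.args[-len(spec.defaults):]                      # ['c', 'd']
kwargs = {str(k): v for k, v in zip(defaults_args, spec.defaults)}   # {'c': 3, 'd': 4}
print(kwargs)
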
@@ -853,8 +916,10 @@ if __name__ == '__main__':
if inspect_args.annotations:
supported_types = _Arguments.get_supported_types()
function_kwargs_types = {
str(k): str(inspect_args.annotations[k].__name__) for k in inspect_args.annotations
if inspect_args.annotations[k] in supported_types}
str(k): str(inspect_args.annotations[k].__name__)
for k in inspect_args.annotations
if inspect_args.annotations[k] in supported_types
}
task_template = cls.task_template.format(
header=task_template_header or cls.default_task_template_header,
@@ -871,11 +936,11 @@ if __name__ == '__main__':
artifact_serialization_function_source=artifact_serialization_function_source,
artifact_serialization_function_name=artifact_serialization_function_name,
artifact_deserialization_function_source=artifact_deserialization_function_source,
artifact_deserialization_function_name=artifact_deserialization_function_name
artifact_deserialization_function_name=artifact_deserialization_function_name,
)
temp_dir = repo if repo and os.path.isdir(repo) else None
with tempfile.NamedTemporaryFile('w', suffix='.py', dir=temp_dir) as temp_file:
with tempfile.NamedTemporaryFile("w", suffix=".py", dir=temp_dir) as temp_file:
temp_file.write(task_template)
temp_file.flush()
@@ -899,38 +964,53 @@ if __name__ == '__main__':
docker_bash_setup_script=docker_bash_setup_script,
output_uri=output_uri,
add_task_init_call=False,
working_directory=working_dir
working_directory=working_dir,
)
entry_point = '{}.py'.format(function_name)
entry_point = "{}.py".format(function_name)
task = populate.create_task(dry_run=dry_run)
if dry_run:
task['script']['diff'] = task_template
task['script']['entry_point'] = entry_point
task['script']['working_dir'] = working_dir or '.'
task['hyperparams'] = {
task["script"]["diff"] = task_template
task["script"]["entry_point"] = entry_point
task["script"]["working_dir"] = working_dir or "."
task["hyperparams"] = {
cls.kwargs_section: {
k: dict(section=cls.kwargs_section, name=k,
value=str(v) if v is not None else '', type=function_kwargs_types.get(k, None))
k: dict(
section=cls.kwargs_section,
name=k,
value=str(v) if v is not None else "",
type=function_kwargs_types.get(k, None),
)
for k, v in (function_kwargs or {}).items()
},
cls.input_artifact_section: {
k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else '')
k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else "")
for k, v in (function_input_artifacts or {}).items()
}
},
}
else:
task.update_task(task_data={
'script': task.data.script.to_dict().update(
{'entry_point': entry_point, 'working_dir': '.', 'diff': task_template})})
hyper_parameters = {'{}/{}'.format(cls.kwargs_section, k): str(v) for k, v in function_kwargs} \
if function_kwargs else {}
hyper_parameters.update(
{'{}/{}'.format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts}
if function_input_artifacts else {}
task.update_task(
task_data={
"script": task.data.script.to_dict().update(
{"entry_point": entry_point, "working_dir": ".", "diff": task_template}
)
}
)
hyper_parameters = (
{"{}/{}".format(cls.kwargs_section, k): str(v) for k, v in function_kwargs}
if function_kwargs
else {}
)
hyper_parameters.update(
{"{}/{}".format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts}
if function_input_artifacts
else {}
)
__function_kwargs_types = (
{"{}/{}".format(cls.kwargs_section, k): v for k, v in function_kwargs_types}
if function_kwargs_types
else None
)
__function_kwargs_types = {'{}/{}'.format(cls.kwargs_section, k): v for k, v in function_kwargs_types} \
if function_kwargs_types else None
task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types)
return task
@@ -940,6 +1020,7 @@ if __name__ == '__main__':
# type: (str) -> str
try:
import ast
try:
# available in Python3.9+
from ast import unparse
@@ -950,8 +1031,8 @@ if __name__ == '__main__':
# noinspection PyBroadException
try:
class TypeHintRemover(ast.NodeTransformer):
class TypeHintRemover(ast.NodeTransformer):
def visit_FunctionDef(self, node):
# remove the return type definition
node.returns = None
@@ -967,7 +1048,7 @@ if __name__ == '__main__':
# and import statements from 'typing'
transformed = TypeHintRemover().visit(parsed_source)
# convert the AST back to source code
return unparse(transformed).lstrip('\n')
return unparse(transformed).lstrip("\n")
except Exception:
# just in case we failed parsing.
return function_source
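
The TypeHintRemover above strips annotations so the generated standalone script does not depend on names that only exist for typing; this hunk shows the visit_FunctionDef part and the final unparse. A self-contained sketch of the same idea, removing return annotations and unparsing the tree (ast.unparse needs Python 3.9+, matching the comment above):

# Illustrative sketch, not part of the diff.
import ast

class ReturnHintRemover(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        node.returns = None  # drop the "-> ..." part of the signature
        return self.generic_visit(node)

source = "def add(a: int, b: int) -> int:\n    return a + b\n"
tree = ReturnHintRemover().visit(ast.parse(source))
print(ast.unparse(tree).lstrip("\n"))  # prints the function without the return annotation
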
@@ -975,14 +1056,16 @@ if __name__ == '__main__':
@staticmethod
def __extract_imports(func):
def add_import_guard(import_):
return ("try:\n "
+ import_.replace("\n", "\n ", import_.count("\n") - 1)
+ "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
)
return (
"try:\n "
+ import_.replace("\n", "\n ", import_.count("\n") - 1)
+ "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
)
# noinspection PyBroadException
try:
import ast
func_module = inspect.getmodule(func)
source = inspect.getsource(func_module)
parsed_source = ast.parse(source)
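
add_import_guard above wraps every collected module-level import in a try/except, so a missing optional dependency only prints a warning when the generated script starts instead of crashing it. A sketch of what the wrapper produces for a hypothetical import line:

# Illustrative sketch, not part of the diff; mirrors the add_import_guard helper above.
def add_import_guard(import_):
    return (
        "try:\n "
        + import_.replace("\n", "\n ", import_.count("\n") - 1)
        + "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
    )

print(add_import_guard("import pandas as pd\n"))  # the import, indented one space inside a try/except block
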
@@ -1006,7 +1089,7 @@ if __name__ == '__main__':
imports = [add_import_guard(import_) for import_ in imports]
return "\n".join(imports)
except Exception as e:
getLogger().warning('Could not fetch function imports: {}'.format(e))
getLogger().warning("Could not fetch function imports: {}".format(e))
return ""
@staticmethod
@@ -1046,7 +1129,7 @@ if __name__ == '__main__':
result.append(f.name)
except Exception as e:
name = getattr(module, "__name__", module)
getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e))
getLogger().warning("Could not fetch function declared in {}: {}".format(name, e))
return result
@staticmethod
@@ -1058,12 +1141,11 @@ if __name__ == '__main__':
func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
except Exception as e:
name = getattr(original_module, "__name__", original_module)
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
getLogger().warning("Could not fetch functions from {}: {}".format(name, e))
func_members_dict = {}
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
decorated_func_source = CreateFromFunction.__sanitize(
inspect.getsource(decorated_func),
sanitize_function=sanitize_function
inspect.getsource(decorated_func), sanitize_function=sanitize_function
)
try:
import ast
@@ -1083,14 +1165,16 @@ if __name__ == '__main__':
decorator_func = func_members_dict.get(name)
if name not in func_members or not decorator_func:
continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
decorator_func,
original_module=original_module,
sanitize_function=sanitize_function
) + "\n\n" + decorated_func_source
decorated_func_source = (
CreateFromFunction.__get_source_with_decorators(
decorator_func, original_module=original_module, sanitize_function=sanitize_function
)
+ "\n\n"
+ decorated_func_source
)
break
except Exception as e:
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
getLogger().warning("Could not fetch full definition of function {}: {}".format(func.__name__, e))
return decorated_func_source
@staticmethod