Black formatting

This commit is contained in:
allegroai 2024-08-18 10:36:42 +03:00
parent cf1178ff5f
commit a9819416fd

View File

@ -3,9 +3,9 @@ import json
import os import os
import re import re
import tempfile import tempfile
from sys import platform
from functools import reduce from functools import reduce
from logging import getLogger from logging import getLogger
from sys import platform
from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any
from pathlib2 import Path from pathlib2 import Path
@ -27,28 +27,28 @@ class CreateAndPopulate(object):
) )
def __init__( def __init__(
self, self,
project_name=None, # type: Optional[str] project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str] task_type=None, # type: Optional[str]
repo=None, # type: Optional[str] repo=None, # type: Optional[str]
branch=None, # type: Optional[str] branch=None, # type: Optional[str]
commit=None, # type: Optional[str] commit=None, # type: Optional[str]
script=None, # type: Optional[str] script=None, # type: Optional[str]
working_directory=None, # type: Optional[str] working_directory=None, # type: Optional[str]
module=None, # type: Optional[str] module=None, # type: Optional[str]
packages=None, # type: Optional[Union[bool, Sequence[str]]] packages=None, # type: Optional[Union[bool, Sequence[str]]]
requirements_file=None, # type: Optional[Union[str, Path]] requirements_file=None, # type: Optional[Union[str, Path]]
docker=None, # type: Optional[str] docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str] docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str] docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str] output_uri=None, # type: Optional[str]
base_task_id=None, # type: Optional[str] base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool add_task_init_call=True, # type: bool
force_single_script_file=False, # type: bool force_single_script_file=False, # type: bool
raise_on_missing_entries=False, # type: bool raise_on_missing_entries=False, # type: bool
verbose=False, # type: bool verbose=False, # type: bool
binary=None # type: Optional[str] binary=None, # type: Optional[str]
): ):
# type: (...) -> None # type: (...) -> None
""" """
@ -106,15 +106,16 @@ class CreateAndPopulate(object):
if not script and not module: if not script and not module:
raise ValueError("Entry point script not provided") raise ValueError("Entry point script not provided")
if not repo and not folder and (script and not Path(script).is_file()): if not repo and not folder and (script and not Path(script).is_file()):
raise ValueError("Script file \'{}\' could not be found".format(script)) raise ValueError("Script file '{}' could not be found".format(script))
if raise_on_missing_entries and commit and branch: if raise_on_missing_entries and commit and branch:
raise ValueError( raise ValueError(
"Specify either a branch/tag or specific commit id, not both (either --commit or --branch)") "Specify either a branch/tag or specific commit id, not both (either --commit or --branch)"
if raise_on_missing_entries and not folder and working_directory and working_directory.startswith('/'): )
raise ValueError("working directory \'{}\', must be relative to repository root") if raise_on_missing_entries and not folder and working_directory and working_directory.startswith("/"):
raise ValueError("working directory '{}', must be relative to repository root")
if requirements_file and not Path(requirements_file).is_file(): if requirements_file and not Path(requirements_file).is_file():
raise ValueError("requirements file could not be found \'{}\'") raise ValueError("requirements file could not be found '{}'")
self.folder = folder self.folder = folder
self.commit = commit self.commit = commit
@ -124,8 +125,9 @@ class CreateAndPopulate(object):
self.module = module self.module = module
self.cwd = working_directory self.cwd = working_directory
assert not packages or isinstance(packages, (tuple, list, bool)) assert not packages or isinstance(packages, (tuple, list, bool))
self.packages = list(packages) if packages is not None and not isinstance(packages, bool) \ self.packages = (
else (packages or None) list(packages) if packages is not None and not isinstance(packages, bool) else (packages or None)
)
self.requirements_file = Path(requirements_file) if requirements_file else None self.requirements_file = Path(requirements_file) if requirements_file else None
self.base_task_id = base_task_id self.base_task_id = base_task_id
self.docker = dict(image=docker, args=docker_args, bash_script=docker_bash_setup_script) self.docker = dict(image=docker, args=docker_args, bash_script=docker_bash_setup_script)
@ -179,14 +181,20 @@ class CreateAndPopulate(object):
stand_alone_script_outside_repo = True stand_alone_script_outside_repo = True
if not os.path.isfile(entry_point) and not stand_alone_script_outside_repo: if not os.path.isfile(entry_point) and not stand_alone_script_outside_repo:
if (not Path(self.script).is_absolute() and not Path(self.cwd).is_absolute() and if (
(Path(self.folder) / self.cwd / self.script).is_file()): not Path(self.script).is_absolute()
and not Path(self.cwd).is_absolute()
and (Path(self.folder) / self.cwd / self.script).is_file()
):
entry_point = (Path(self.folder) / self.cwd / self.script).as_posix() entry_point = (Path(self.folder) / self.cwd / self.script).as_posix()
elif (Path(self.cwd).is_absolute() and not Path(self.script).is_absolute() and elif (
(Path(self.cwd) / self.script).is_file()): Path(self.cwd).is_absolute()
and not Path(self.script).is_absolute()
and (Path(self.cwd) / self.script).is_file()
):
entry_point = (Path(self.cwd) / self.script).as_posix() entry_point = (Path(self.cwd) / self.script).as_posix()
else: else:
raise ValueError("Script entrypoint file \'{}\' could not be found".format(entry_point)) raise ValueError("Script entrypoint file '{}' could not be found".format(entry_point))
local_entry_file = entry_point local_entry_file = entry_point
@ -215,20 +223,25 @@ class CreateAndPopulate(object):
detailed_req_report=False, detailed_req_report=False,
force_single_script=True, force_single_script=True,
) )
if repo_info.script['diff']: if repo_info.script["diff"]:
print("Warning: local git repo diff is ignored, " print(
"storing only the standalone script form {}".format(self.script)) "Warning: local git repo diff is ignored, "
repo_info.script['diff'] = a_repo_info.script['diff'] or '' "storing only the standalone script form {}".format(self.script)
repo_info.script['entry_point'] = a_repo_info.script['entry_point'] )
repo_info.script["diff"] = a_repo_info.script["diff"] or ""
repo_info.script["entry_point"] = a_repo_info.script["entry_point"]
if a_create_requirements: if a_create_requirements:
repo_info['requirements'] = a_repo_info.script.get('requirements') or {} repo_info["requirements"] = a_repo_info.script.get("requirements") or {}
# check if we have no repository and no requirements raise error # check if we have no repository and no requirements raise error
if self.raise_on_missing_entries and (not self.requirements_file and not self.packages) \ if (
and not self.repo and ( self.raise_on_missing_entries
not repo_info or not repo_info.script or not repo_info.script.get('repository')) \ and (not self.requirements_file and not self.packages)
and (not entry_point or not entry_point.endswith(".sh")): and not self.repo
raise ValueError("Standalone script detected \'{}\', but no requirements provided".format(self.script)) and (not repo_info or not repo_info.script or not repo_info.script.get("repository"))
and (not entry_point or not entry_point.endswith(".sh"))
):
raise ValueError("Standalone script detected '{}', but no requirements provided".format(self.script))
if dry_run: if dry_run:
task = None task = None
task_state = dict( task_state = dict(
@ -237,106 +250,127 @@ class CreateAndPopulate(object):
type=str(self.task_type or Task.TaskTypes.training), type=str(self.task_type or Task.TaskTypes.training),
) # type: dict ) # type: dict
if self.output_uri is not None: if self.output_uri is not None:
task_state['output'] = dict(destination=self.output_uri) task_state["output"] = dict(destination=self.output_uri)
else: else:
task_state = dict(script={}) task_state = dict(script={})
if self.base_task_id: if self.base_task_id:
if self.verbose: if self.verbose:
print('Cloning task {}'.format(self.base_task_id)) print("Cloning task {}".format(self.base_task_id))
task = Task.clone(source_task=self.base_task_id, project=Task.get_project_id(self.project_name)) task = Task.clone(source_task=self.base_task_id, project=Task.get_project_id(self.project_name))
self._set_output_uri(task) self._set_output_uri(task)
else: else:
# noinspection PyProtectedMember # noinspection PyProtectedMember
task = Task._create( task = Task._create(
task_name=self.task_name, project_name=self.project_name, task_name=self.task_name,
task_type=self.task_type or Task.TaskTypes.training) project_name=self.project_name,
task_type=self.task_type or Task.TaskTypes.training,
)
self._set_output_uri(task) self._set_output_uri(task)
# if there is nothing to populate, return # if there is nothing to populate, return
if not any([ if not any(
self.folder, self.commit, self.branch, self.repo, self.script, self.module, self.cwd, [
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values())) self.folder,
): self.commit,
self.branch,
self.repo,
self.script,
self.module,
self.cwd,
self.packages,
self.requirements_file,
self.base_task_id,
]
+ (list(self.docker.values()))
):
return task return task
# clear the script section # clear the script section
task_state['script'] = {} task_state["script"] = {}
if repo_info: if repo_info:
task_state['script']['repository'] = repo_info.script['repository'] task_state["script"]["repository"] = repo_info.script["repository"]
task_state['script']['version_num'] = repo_info.script['version_num'] task_state["script"]["version_num"] = repo_info.script["version_num"]
task_state['script']['branch'] = repo_info.script['branch'] task_state["script"]["branch"] = repo_info.script["branch"]
task_state['script']['diff'] = repo_info.script['diff'] or '' task_state["script"]["diff"] = repo_info.script["diff"] or ""
task_state['script']['working_dir'] = repo_info.script['working_dir'] task_state["script"]["working_dir"] = repo_info.script["working_dir"]
task_state['script']['entry_point'] = repo_info.script['entry_point'] task_state["script"]["entry_point"] = repo_info.script["entry_point"]
task_state['script']['binary'] = self.binary or ('/bin/bash' if ( task_state["script"]["binary"] = self.binary or (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and "/bin/bash"
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \ if (
else repo_info.script['binary']) (repo_info.script["entry_point"] or "").lower().strip().endswith(".sh")
task_state['script']['requirements'] = repo_info.script.get('requirements') or {} and not (repo_info.script["entry_point"] or "").lower().strip().startswith("-m ")
)
else repo_info.script["binary"]
)
task_state["script"]["requirements"] = repo_info.script.get("requirements") or {}
if self.cwd: if self.cwd:
cwd = self.cwd cwd = self.cwd
if not Path(cwd).is_absolute(): if not Path(cwd).is_absolute():
# cwd should be relative to the repo_root, but we need the full path # cwd should be relative to the repo_root, but we need the full path
# (repo_root + cwd) in order to resolve the entry point # (repo_root + cwd) in order to resolve the entry point
cwd = os.path.normpath((Path(repo_info.script['repo_root']) / self.cwd).as_posix()) cwd = os.path.normpath((Path(repo_info.script["repo_root"]) / self.cwd).as_posix())
if not Path(cwd).is_dir(): if not Path(cwd).is_dir():
# we need to leave it as is, we have no idea, and this is a repo # we need to leave it as is, we have no idea, and this is a repo
cwd = self.cwd cwd = self.cwd
elif not Path(cwd).is_dir(): elif not Path(cwd).is_dir():
# we were passed an absolute dir and it does not exist # we were passed an absolute dir and it does not exist
raise ValueError("Working directory \'{}\' could not be found".format(cwd)) raise ValueError("Working directory '{}' could not be found".format(cwd))
if self.module: if self.module:
entry_point = "-m {}".format(self.module) entry_point = "-m {}".format(self.module)
elif stand_alone_script_outside_repo: elif stand_alone_script_outside_repo:
# this should be relative and the temp file we generated # this should be relative and the temp file we generated
entry_point = repo_info.script['entry_point'] entry_point = repo_info.script["entry_point"]
else: else:
entry_point = os.path.normpath( entry_point = os.path.normpath(
Path(repo_info.script['repo_root']) / Path(repo_info.script["repo_root"])
repo_info.script['working_dir'] / repo_info.script['entry_point'] / repo_info.script["working_dir"]
/ repo_info.script["entry_point"]
) )
# resolve entry_point relative to the current working directory # resolve entry_point relative to the current working directory
if Path(cwd).is_absolute(): if Path(cwd).is_absolute():
entry_point = Path(entry_point).relative_to(cwd).as_posix() entry_point = Path(entry_point).relative_to(cwd).as_posix()
else: else:
entry_point = repo_info.script['entry_point'] entry_point = repo_info.script["entry_point"]
# restore cwd - make it relative to the repo_root again # restore cwd - make it relative to the repo_root again
if Path(cwd).is_absolute(): if Path(cwd).is_absolute():
# now cwd is relative again # now cwd is relative again
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix() cwd = Path(cwd).relative_to(repo_info.script["repo_root"]).as_posix()
# make sure we always have / (never \\) # make sure we always have / (never \\)
if platform == "win32": if platform == "win32":
entry_point = entry_point.replace('\\', '/') if entry_point else "" entry_point = entry_point.replace("\\", "/") if entry_point else ""
cwd = cwd.replace('\\', '/') if cwd else "" cwd = cwd.replace("\\", "/") if cwd else ""
task_state['script']['entry_point'] = entry_point or "" task_state["script"]["entry_point"] = entry_point or ""
task_state['script']['working_dir'] = cwd or "." task_state["script"]["working_dir"] = cwd or "."
elif self.repo: elif self.repo:
cwd = '/'.join([p for p in (self.cwd or '.').split('/') if p and p != '.']) cwd = "/".join([p for p in (self.cwd or ".").split("/") if p and p != "."])
# normalize backslashes and remove first one # normalize backslashes and remove first one
if self.module: if self.module:
entry_point = "-m {}".format(self.module) entry_point = "-m {}".format(self.module)
else: else:
entry_point = '/'.join([p for p in self.script.split('/') if p and p != '.']) entry_point = "/".join([p for p in self.script.split("/") if p and p != "."])
if cwd and entry_point.startswith(cwd + '/'): if cwd and entry_point.startswith(cwd + "/"):
entry_point = entry_point[len(cwd) + 1:] entry_point = entry_point[len(cwd) + 1 :]
task_state['script']['repository'] = self.repo task_state["script"]["repository"] = self.repo
task_state['script']['version_num'] = self.commit or None task_state["script"]["version_num"] = self.commit or None
task_state['script']['branch'] = self.branch or None task_state["script"]["branch"] = self.branch or None
task_state['script']['diff'] = '' task_state["script"]["diff"] = ""
task_state['script']['working_dir'] = cwd or '.' task_state["script"]["working_dir"] = cwd or "."
task_state['script']['entry_point'] = entry_point or "" task_state["script"]["entry_point"] = entry_point or ""
if self.script and Path(self.script).is_file() and ( if (
self.force_single_script_file or Path(self.script).is_absolute()): self.script
and Path(self.script).is_file()
and (self.force_single_script_file or Path(self.script).is_absolute())
):
self.force_single_script_file = True self.force_single_script_file = True
create_requirements = self.packages is True create_requirements = self.packages is True
repo_info, requirements = ScriptInfo.get( repo_info, requirements = ScriptInfo.get(
@ -349,29 +383,37 @@ class CreateAndPopulate(object):
detailed_req_report=False, detailed_req_report=False,
force_single_script=True, force_single_script=True,
) )
task_state['script']['binary'] = self.binary or ('/bin/bash' if ( task_state["script"]["binary"] = self.binary or (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and "/bin/bash"
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \ if (
else repo_info.script['binary']) (repo_info.script["entry_point"] or "").lower().strip().endswith(".sh")
task_state['script']['diff'] = repo_info.script['diff'] or '' and not (repo_info.script["entry_point"] or "").lower().strip().startswith("-m ")
task_state['script']['entry_point'] = repo_info.script['entry_point'] )
else repo_info.script["binary"]
)
task_state["script"]["diff"] = repo_info.script["diff"] or ""
task_state["script"]["entry_point"] = repo_info.script["entry_point"]
if create_requirements: if create_requirements:
task_state['script']['requirements'] = repo_info.script.get('requirements') or {} task_state["script"]["requirements"] = repo_info.script.get("requirements") or {}
else: else:
if self.binary: if self.binary:
task_state["script"]["binary"] = self.binary task_state["script"]["binary"] = self.binary
elif entry_point and entry_point.lower().strip().endswith(".sh") and not \ elif (
entry_point.lower().strip().startswith("-m"): entry_point
and entry_point.lower().strip().endswith(".sh")
and not entry_point.lower().strip().startswith("-m")
):
task_state["script"]["binary"] = "/bin/bash" task_state["script"]["binary"] = "/bin/bash"
else: else:
# standalone task # standalone task
task_state['script']['entry_point'] = self.script if self.script else \ task_state["script"]["entry_point"] = (
("-m {}".format(self.module) if self.module else "") self.script if self.script else ("-m {}".format(self.module) if self.module else "")
task_state['script']['working_dir'] = '.' )
task_state["script"]["working_dir"] = "."
# update requirements # update requirements
reqs = [] reqs = []
if self.requirements_file: if self.requirements_file:
with open(self.requirements_file.as_posix(), 'rt') as f: with open(self.requirements_file.as_posix(), "rt") as f:
reqs = [line.strip() for line in f.readlines()] reqs = [line.strip() for line in f.readlines()]
if self.packages and self.packages is not True: if self.packages and self.packages is not True:
reqs += self.packages reqs += self.packages
@ -379,66 +421,76 @@ class CreateAndPopulate(object):
# make sure we have clearml. # make sure we have clearml.
clearml_found = False clearml_found = False
for line in reqs: for line in reqs:
if line.strip().startswith('#'): if line.strip().startswith("#"):
continue continue
package = reduce(lambda a, b: a.split(b)[0], "#;@=~<>[", line).strip() package = reduce(lambda a, b: a.split(b)[0], "#;@=~<>[", line).strip()
if package == 'clearml': if package == "clearml":
clearml_found = True clearml_found = True
break break
if not clearml_found: if not clearml_found:
reqs.append('clearml') reqs.append("clearml")
task_state['script']['requirements'] = {'pip': '\n'.join(reqs)} task_state["script"]["requirements"] = {"pip": "\n".join(reqs)}
elif not self.repo and repo_info and not repo_info.script.get('requirements'): elif not self.repo and repo_info and not repo_info.script.get("requirements"):
# we are in local mode, make sure we have "requirements.txt" it is a must # we are in local mode, make sure we have "requirements.txt" it is a must
reqs_txt_file = Path(repo_info.script['repo_root']) / "requirements.txt" reqs_txt_file = Path(repo_info.script["repo_root"]) / "requirements.txt"
poetry_toml_file = Path(repo_info.script['repo_root']) / "pyproject.toml" poetry_toml_file = Path(repo_info.script["repo_root"]) / "pyproject.toml"
if self.raise_on_missing_entries and not reqs_txt_file.is_file() and not poetry_toml_file.is_file(): if self.raise_on_missing_entries and not reqs_txt_file.is_file() and not poetry_toml_file.is_file():
raise ValueError( raise ValueError(
"requirements.txt not found [{}] " "requirements.txt not found [{}] "
"Use --requirements or --packages".format(reqs_txt_file.as_posix())) "Use --requirements or --packages".format(reqs_txt_file.as_posix())
)
if self.add_task_init_call: if self.add_task_init_call:
script_entry = ('/' + task_state['script'].get('working_dir', '.') script_entry = (
+ '/' + task_state['script']['entry_point']) "/" + task_state["script"].get("working_dir", ".") + "/" + task_state["script"]["entry_point"]
)
if platform == "win32": if platform == "win32":
script_entry = os.path.normpath(script_entry).replace('\\', '/') script_entry = os.path.normpath(script_entry).replace("\\", "/")
else: else:
script_entry = os.path.abspath(script_entry) script_entry = os.path.abspath(script_entry)
idx_a = 0 idx_a = 0
lines = None lines = None
# find the right entry for the patch if we have a local file (basically after __future__ # find the right entry for the patch if we have a local file (basically after __future__
if (local_entry_file and not stand_alone_script_outside_repo and not self.module and if (
str(local_entry_file).lower().endswith(".py")): local_entry_file
with open(local_entry_file, 'rt') as f: and not stand_alone_script_outside_repo
and not self.module
and str(local_entry_file).lower().endswith(".py")
):
with open(local_entry_file, "rt") as f:
lines = f.readlines() lines = f.readlines()
future_found = self._locate_future_import(lines) future_found = self._locate_future_import(lines)
if future_found >= 0: if future_found >= 0:
idx_a = future_found + 1 idx_a = future_found + 1
task_init_patch = '' task_init_patch = ""
if ((self.repo or task_state.get('script', {}).get('repository')) and if (
not self.force_single_script_file and not stand_alone_script_outside_repo): (self.repo or task_state.get("script", {}).get("repository"))
and not self.force_single_script_file
and not stand_alone_script_outside_repo
):
# if we do not have requirements, add clearml to the requirements.txt # if we do not have requirements, add clearml to the requirements.txt
if not reqs: if not reqs:
task_init_patch += \ task_init_patch += (
"diff --git a/requirements.txt b/requirements.txt\n" \ "diff --git a/requirements.txt b/requirements.txt\n"
"--- a/requirements.txt\n" \ "--- a/requirements.txt\n"
"+++ b/requirements.txt\n" \ "+++ b/requirements.txt\n"
"@@ -0,0 +1,1 @@\n" \ "@@ -0,0 +1,1 @@\n"
"+clearml\n" "+clearml\n"
)
# Add Task.init call # Add Task.init call
if not self.module and script_entry and str(script_entry).lower().endswith(".py"): if not self.module and script_entry and str(script_entry).lower().endswith(".py"):
task_init_patch += \ task_init_patch += (
"diff --git a{script_entry} b{script_entry}\n" \ "diff --git a{script_entry} b{script_entry}\n"
"--- a{script_entry}\n" \ "--- a{script_entry}\n"
"+++ b{script_entry}\n" \ "+++ b{script_entry}\n"
"@@ -{idx_a},0 +{idx_b},4 @@\n" \ "@@ -{idx_a},0 +{idx_b},4 @@\n"
"+try: from allegroai import Task\n" \ "+try: from allegroai import Task\n"
"+except ImportError: from clearml import Task\n" \ "+except ImportError: from clearml import Task\n"
"+(__name__ != \"__main__\") or Task.init()\n" \ '+(__name__ != "__main__") or Task.init()\n'
"+\n".format( "+\n".format(script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1)
script_entry=script_entry, idx_a=idx_a, idx_b=idx_a + 1) )
elif self.module: elif self.module:
# if we are here, do nothing # if we are here, do nothing
pass pass
@ -449,57 +501,62 @@ class CreateAndPopulate(object):
"except ImportError: from clearml import Task\n", "except ImportError: from clearml import Task\n",
'(__name__ != "__main__") or Task.init()\n\n', '(__name__ != "__main__") or Task.init()\n\n',
] ]
task_state['script']['diff'] = ''.join(lines[:idx_a] + init_lines + lines[idx_a:]) task_state["script"]["diff"] = "".join(lines[:idx_a] + init_lines + lines[idx_a:])
# no need to add anything, we patched it. # no need to add anything, we patched it.
task_init_patch = "" task_init_patch = ""
elif str(script_entry or "").lower().endswith(".py"): elif str(script_entry or "").lower().endswith(".py"):
# Add Task.init call # Add Task.init call
# if we are here it means we do not have a git diff, but a single script file # if we are here it means we do not have a git diff, but a single script file
task_init_patch += \ task_init_patch += (
"try: from allegroai import Task\n" \ "try: from allegroai import Task\n"
"except ImportError: from clearml import Task\n" \ "except ImportError: from clearml import Task\n"
"(__name__ != \"__main__\") or Task.init()\n\n" '(__name__ != "__main__") or Task.init()\n\n'
task_state['script']['diff'] = task_init_patch + task_state['script'].get('diff', '') )
task_state["script"]["diff"] = task_init_patch + task_state["script"].get("diff", "")
task_init_patch = "" task_init_patch = ""
# make sure we add the diff at the end of the current diff # make sure we add the diff at the end of the current diff
task_state['script']['diff'] = task_state['script'].get('diff', '') task_state["script"]["diff"] = task_state["script"].get("diff", "")
if task_state['script']['diff'] and not task_state['script']['diff'].endswith('\n'): if task_state["script"]["diff"] and not task_state["script"]["diff"].endswith("\n"):
task_state['script']['diff'] += '\n' task_state["script"]["diff"] += "\n"
task_state['script']['diff'] += task_init_patch task_state["script"]["diff"] += task_init_patch
# set base docker image if provided # set base docker image if provided
if self.docker: if self.docker:
if dry_run: if dry_run:
task_state['container'] = dict( task_state["container"] = dict(
image=self.docker.get('image') or '', image=self.docker.get("image") or "",
arguments=self.docker.get('args') or '', arguments=self.docker.get("args") or "",
setup_shell_script=self.docker.get('bash_script') or '', setup_shell_script=self.docker.get("bash_script") or "",
) )
else: else:
task.set_base_docker( task.set_base_docker(
docker_image=self.docker.get('image'), docker_image=self.docker.get("image"),
docker_arguments=self.docker.get('args'), docker_arguments=self.docker.get("args"),
docker_setup_bash_script=self.docker.get('bash_script'), docker_setup_bash_script=self.docker.get("bash_script"),
) )
if self.verbose: if self.verbose:
if task_state['script']['repository']: if task_state["script"]["repository"]:
repo_details = {k: v for k, v in task_state['script'].items() repo_details = {
if v and k not in ('diff', 'requirements', 'binary')} k: v for k, v in task_state["script"].items() if v and k not in ("diff", "requirements", "binary")
print('Repository Detected\n{}'.format(json.dumps(repo_details, indent=2))) }
print("Repository Detected\n{}".format(json.dumps(repo_details, indent=2)))
else: else:
print('Standalone script detected\n Script: {}'.format(self.script)) print("Standalone script detected\n Script: {}".format(self.script))
if task_state['script'].get('requirements') and \ if task_state["script"].get("requirements") and task_state["script"]["requirements"].get("pip"):
task_state['script']['requirements'].get('pip'): print(
print('Requirements:{}{}'.format( "Requirements:{}{}".format(
'\n Using requirements.txt: {}'.format( "\n Using requirements.txt: {}".format(self.requirements_file.as_posix())
self.requirements_file.as_posix()) if self.requirements_file else '', if self.requirements_file
'\n {}Packages: {}'.format('Additional ' if self.requirements_file else '', self.packages) else "",
if self.packages else '' "\n {}Packages: {}".format("Additional " if self.requirements_file else "", self.packages)
)) if self.packages
else "",
)
)
if self.docker: if self.docker:
print('Base docker image: {}'.format(self.docker)) print("Base docker image: {}".format(self.docker))
if dry_run: if dry_run:
return task_state return task_state
@ -538,18 +595,17 @@ class CreateAndPopulate(object):
args_list.append(a) args_list.append(a)
continue continue
try: try:
parts = a.split('=', 1) parts = a.split("=", 1)
assert len(parts) == 2 assert len(parts) == 2
args_list.append(parts) args_list.append(parts)
except Exception: except Exception:
raise ValueError( raise ValueError("Failed parsing argument '{}', arguments must be in '<key>=<value>' format")
"Failed parsing argument \'{}\', arguments must be in \'<key>=<value>\' format")
if not self.task: if not self.task:
return return
task_params = self.task.get_parameters() task_params = self.task.get_parameters()
args_list = {'Args/{}'.format(k): v for k, v in args_list} args_list = {"Args/{}".format(k): v for k, v in args_list}
task_params.update(args_list) task_params.update(args_list)
self.task.set_parameters(task_params) self.task.set_parameters(task_params)
@ -569,8 +625,11 @@ class CreateAndPopulate(object):
""" """
# skip over the first two lines, they are ours # skip over the first two lines, they are ours
# then skip over empty or comment lines # then skip over empty or comment lines
lines = [(i, line.split('#', 1)[0].rstrip()) for i, line in enumerate(lines) lines = [
if line.strip('\r\n\t ') and not line.strip().startswith('#')] (i, line.split("#", 1)[0].rstrip())
for i, line in enumerate(lines)
if line.strip("\r\n\t ") and not line.strip().startswith("#")
]
# remove triple quotes ' """ ' # remove triple quotes ' """ '
nested_c = -1 nested_c = -1
@ -597,11 +656,11 @@ class CreateAndPopulate(object):
# check the last import block # check the last import block
i, line = lines[found_index] i, line = lines[found_index]
# wither we have \\ character at the end of the line or the line is indented # wither we have \\ character at the end of the line or the line is indented
parenthesized_lines = '(' in line and ')' not in line parenthesized_lines = "(" in line and ")" not in line
while line.endswith('\\') or parenthesized_lines: while line.endswith("\\") or parenthesized_lines:
found_index += 1 found_index += 1
i, line = lines[found_index] i, line = lines[found_index]
if ')' in line: if ")" in line:
break break
else: else:
@ -671,33 +730,33 @@ if __name__ == '__main__':
@classmethod @classmethod
def create_task_from_function( def create_task_from_function(
cls, cls,
a_function, # type: Callable a_function, # type: Callable
function_kwargs=None, # type: Optional[Dict[str, Any]] function_kwargs=None, # type: Optional[Dict[str, Any]]
function_input_artifacts=None, # type: Optional[Dict[str, str]] function_input_artifacts=None, # type: Optional[Dict[str, str]]
function_return=None, # type: Optional[List[str]] function_return=None, # type: Optional[List[str]]
project_name=None, # type: Optional[str] project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str] task_type=None, # type: Optional[str]
auto_connect_frameworks=None, # type: Optional[dict] auto_connect_frameworks=None, # type: Optional[dict]
auto_connect_arg_parser=None, # type: Optional[dict] auto_connect_arg_parser=None, # type: Optional[dict]
repo=None, # type: Optional[str] repo=None, # type: Optional[str]
branch=None, # type: Optional[str] branch=None, # type: Optional[str]
commit=None, # type: Optional[str] commit=None, # type: Optional[str]
packages=None, # type: Optional[Union[str, Sequence[str]]] packages=None, # type: Optional[Union[str, Sequence[str]]]
docker=None, # type: Optional[str] docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str] docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str] docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str] output_uri=None, # type: Optional[str]
helper_functions=None, # type: Optional[Sequence[Callable]] helper_functions=None, # type: Optional[Sequence[Callable]]
dry_run=False, # type: bool dry_run=False, # type: bool
task_template_header=None, # type: Optional[str] task_template_header=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]] _sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]] _sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
skip_global_imports=False, # type: bool skip_global_imports=False, # type: bool
working_dir=None # type: Optional[str] working_dir=None, # type: Optional[str]
): ):
# type: (...) -> Optional[Dict, Task] # type: (...) -> Optional[Dict, Task]
""" """
@ -793,8 +852,8 @@ if __name__ == '__main__':
if auto_connect_arg_parser is None: if auto_connect_arg_parser is None:
auto_connect_arg_parser = True auto_connect_arg_parser = True
assert (not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict))) assert not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict))
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict))) assert not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict))
function_source, function_name = CreateFromFunction.__extract_function_information( function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
@ -819,9 +878,9 @@ if __name__ == '__main__':
function_input_artifacts = function_input_artifacts or dict() function_input_artifacts = function_input_artifacts or dict()
# verify artifact kwargs: # verify artifact kwargs:
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()): if not all(len(v.split(".", 1)) == 2 for v in function_input_artifacts.values()):
raise ValueError( raise ValueError(
'function_input_artifacts={}, it must in the format: ' "function_input_artifacts={}, it must in the format: "
'{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts) '{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts)
) )
inspect_args = None inspect_args = None
@ -835,16 +894,20 @@ if __name__ == '__main__':
# adjust the defaults so they match the args (match from the end) # adjust the defaults so they match the args (match from the end)
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args): if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):] inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals) :]
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args): if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
getLogger().warning( getLogger().warning(
'Ignoring default argument values: ' "Ignoring default argument values: "
'could not find all default valued for: \'{}\''.format(function_name)) "could not find all default valued for: '{}'".format(function_name)
)
inspect_defaults_vals = [] inspect_defaults_vals = []
function_kwargs = {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)} \ function_kwargs = (
if inspect_defaults_vals else {str(k): None for k in inspect_defaults_args} {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)}
if inspect_defaults_vals
else {str(k): None for k in inspect_defaults_args}
)
if function_kwargs: if function_kwargs:
if not inspect_args: if not inspect_args:
@ -853,8 +916,10 @@ if __name__ == '__main__':
if inspect_args.annotations: if inspect_args.annotations:
supported_types = _Arguments.get_supported_types() supported_types = _Arguments.get_supported_types()
function_kwargs_types = { function_kwargs_types = {
str(k): str(inspect_args.annotations[k].__name__) for k in inspect_args.annotations str(k): str(inspect_args.annotations[k].__name__)
if inspect_args.annotations[k] in supported_types} for k in inspect_args.annotations
if inspect_args.annotations[k] in supported_types
}
task_template = cls.task_template.format( task_template = cls.task_template.format(
header=task_template_header or cls.default_task_template_header, header=task_template_header or cls.default_task_template_header,
@ -871,11 +936,11 @@ if __name__ == '__main__':
artifact_serialization_function_source=artifact_serialization_function_source, artifact_serialization_function_source=artifact_serialization_function_source,
artifact_serialization_function_name=artifact_serialization_function_name, artifact_serialization_function_name=artifact_serialization_function_name,
artifact_deserialization_function_source=artifact_deserialization_function_source, artifact_deserialization_function_source=artifact_deserialization_function_source,
artifact_deserialization_function_name=artifact_deserialization_function_name artifact_deserialization_function_name=artifact_deserialization_function_name,
) )
temp_dir = repo if repo and os.path.isdir(repo) else None temp_dir = repo if repo and os.path.isdir(repo) else None
with tempfile.NamedTemporaryFile('w', suffix='.py', dir=temp_dir) as temp_file: with tempfile.NamedTemporaryFile("w", suffix=".py", dir=temp_dir) as temp_file:
temp_file.write(task_template) temp_file.write(task_template)
temp_file.flush() temp_file.flush()
@ -899,38 +964,53 @@ if __name__ == '__main__':
docker_bash_setup_script=docker_bash_setup_script, docker_bash_setup_script=docker_bash_setup_script,
output_uri=output_uri, output_uri=output_uri,
add_task_init_call=False, add_task_init_call=False,
working_directory=working_dir working_directory=working_dir,
) )
entry_point = '{}.py'.format(function_name) entry_point = "{}.py".format(function_name)
task = populate.create_task(dry_run=dry_run) task = populate.create_task(dry_run=dry_run)
if dry_run: if dry_run:
task['script']['diff'] = task_template task["script"]["diff"] = task_template
task['script']['entry_point'] = entry_point task["script"]["entry_point"] = entry_point
task['script']['working_dir'] = working_dir or '.' task["script"]["working_dir"] = working_dir or "."
task['hyperparams'] = { task["hyperparams"] = {
cls.kwargs_section: { cls.kwargs_section: {
k: dict(section=cls.kwargs_section, name=k, k: dict(
value=str(v) if v is not None else '', type=function_kwargs_types.get(k, None)) section=cls.kwargs_section,
name=k,
value=str(v) if v is not None else "",
type=function_kwargs_types.get(k, None),
)
for k, v in (function_kwargs or {}).items() for k, v in (function_kwargs or {}).items()
}, },
cls.input_artifact_section: { cls.input_artifact_section: {
k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else '') k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else "")
for k, v in (function_input_artifacts or {}).items() for k, v in (function_input_artifacts or {}).items()
} },
} }
else: else:
task.update_task(task_data={ task.update_task(
'script': task.data.script.to_dict().update( task_data={
{'entry_point': entry_point, 'working_dir': '.', 'diff': task_template})}) "script": task.data.script.to_dict().update(
hyper_parameters = {'{}/{}'.format(cls.kwargs_section, k): str(v) for k, v in function_kwargs} \ {"entry_point": entry_point, "working_dir": ".", "diff": task_template}
if function_kwargs else {} )
hyper_parameters.update( }
{'{}/{}'.format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts} )
if function_input_artifacts else {} hyper_parameters = (
{"{}/{}".format(cls.kwargs_section, k): str(v) for k, v in function_kwargs}
if function_kwargs
else {}
)
hyper_parameters.update(
{"{}/{}".format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts}
if function_input_artifacts
else {}
)
__function_kwargs_types = (
{"{}/{}".format(cls.kwargs_section, k): v for k, v in function_kwargs_types}
if function_kwargs_types
else None
) )
__function_kwargs_types = {'{}/{}'.format(cls.kwargs_section, k): v for k, v in function_kwargs_types} \
if function_kwargs_types else None
task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types) task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types)
return task return task
@ -940,6 +1020,7 @@ if __name__ == '__main__':
# type: (str) -> str # type: (str) -> str
try: try:
import ast import ast
try: try:
# available in Python3.9+ # available in Python3.9+
from ast import unparse from ast import unparse
@ -950,8 +1031,8 @@ if __name__ == '__main__':
# noinspection PyBroadException # noinspection PyBroadException
try: try:
class TypeHintRemover(ast.NodeTransformer):
class TypeHintRemover(ast.NodeTransformer):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
# remove the return type definition # remove the return type definition
node.returns = None node.returns = None
@ -967,7 +1048,7 @@ if __name__ == '__main__':
# and import statements from 'typing' # and import statements from 'typing'
transformed = TypeHintRemover().visit(parsed_source) transformed = TypeHintRemover().visit(parsed_source)
# convert the AST back to source code # convert the AST back to source code
return unparse(transformed).lstrip('\n') return unparse(transformed).lstrip("\n")
except Exception: except Exception:
# just in case we failed parsing. # just in case we failed parsing.
return function_source return function_source
@ -975,14 +1056,16 @@ if __name__ == '__main__':
@staticmethod @staticmethod
def __extract_imports(func): def __extract_imports(func):
def add_import_guard(import_): def add_import_guard(import_):
return ("try:\n " return (
+ import_.replace("\n", "\n ", import_.count("\n") - 1) "try:\n "
+ "\nexcept Exception as e:\n print('Import error: ' + str(e))\n" + import_.replace("\n", "\n ", import_.count("\n") - 1)
) + "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import ast import ast
func_module = inspect.getmodule(func) func_module = inspect.getmodule(func)
source = inspect.getsource(func_module) source = inspect.getsource(func_module)
parsed_source = ast.parse(source) parsed_source = ast.parse(source)
@ -1006,7 +1089,7 @@ if __name__ == '__main__':
imports = [add_import_guard(import_) for import_ in imports] imports = [add_import_guard(import_) for import_ in imports]
return "\n".join(imports) return "\n".join(imports)
except Exception as e: except Exception as e:
getLogger().warning('Could not fetch function imports: {}'.format(e)) getLogger().warning("Could not fetch function imports: {}".format(e))
return "" return ""
@staticmethod @staticmethod
@ -1046,7 +1129,7 @@ if __name__ == '__main__':
result.append(f.name) result.append(f.name)
except Exception as e: except Exception as e:
name = getattr(module, "__name__", module) name = getattr(module, "__name__", module)
getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e)) getLogger().warning("Could not fetch function declared in {}: {}".format(name, e))
return result return result
@staticmethod @staticmethod
@ -1058,12 +1141,11 @@ if __name__ == '__main__':
func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction)) func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
except Exception as e: except Exception as e:
name = getattr(original_module, "__name__", original_module) name = getattr(original_module, "__name__", original_module)
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e)) getLogger().warning("Could not fetch functions from {}: {}".format(name, e))
func_members_dict = {} func_members_dict = {}
decorated_func = CreateFromFunction._deep_extract_wrapped(func) decorated_func = CreateFromFunction._deep_extract_wrapped(func)
decorated_func_source = CreateFromFunction.__sanitize( decorated_func_source = CreateFromFunction.__sanitize(
inspect.getsource(decorated_func), inspect.getsource(decorated_func), sanitize_function=sanitize_function
sanitize_function=sanitize_function
) )
try: try:
import ast import ast
@ -1083,14 +1165,16 @@ if __name__ == '__main__':
decorator_func = func_members_dict.get(name) decorator_func = func_members_dict.get(name)
if name not in func_members or not decorator_func: if name not in func_members or not decorator_func:
continue continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators( decorated_func_source = (
decorator_func, CreateFromFunction.__get_source_with_decorators(
original_module=original_module, decorator_func, original_module=original_module, sanitize_function=sanitize_function
sanitize_function=sanitize_function )
) + "\n\n" + decorated_func_source + "\n\n"
+ decorated_func_source
)
break break
except Exception as e: except Exception as e:
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e)) getLogger().warning("Could not fetch full definition of function {}: {}".format(func.__name__, e))
return decorated_func_source return decorated_func_source
@staticmethod @staticmethod