Optimize repository query and requirements analysis

This commit is contained in:
allegroai 2019-07-08 23:28:55 +03:00
parent 96abe3ef04
commit 22b18e7338
2 changed files with 36 additions and 21 deletions

View File

@ -86,23 +86,40 @@ class ScriptRequirements(object):
@staticmethod @staticmethod
def create_requirements_txt(reqs): def create_requirements_txt(reqs):
# write requirements.txt # write requirements.txt
# python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n' requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
for k, v in reqs.sorted_items():
# requirement summary
requirements_txt += '\n' requirements_txt += '\n'
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()]) for k, v in reqs.sorted_items():
# requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
if k == '-e': if k == '-e':
requirements_txt += '{0} {1}\n'.format(k, v.version) requirements_txt += '{0} {1}\n'.format(k, v.version)
elif v: elif v:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version) requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version)
else: else:
requirements_txt += '{0}\n'.format(k) requirements_txt += '{0}\n'.format(k)
# requirements details (in comments)
requirements_txt += '\n' + \
'# Detailed import analysis\n' \
'# **************************\n'
for k, v in reqs.sorted_items():
requirements_txt += '\n'
if k == '-e':
requirements_txt += '# IMPORT PACKAGE {0} {1}\n'.format(k, v.version)
else:
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
return requirements_txt return requirements_txt
class _JupyterObserver(object): class _JupyterObserver(object):
_thread = None _thread = None
_exit_event = Event() _exit_event = Event()
_sample_frequency = 60. _sample_frequency = 30.
_first_sample_frequency = 3. _first_sample_frequency = 3.
@classmethod @classmethod
@ -228,7 +245,8 @@ class ScriptInfo(object):
break break
notebook_path = cur_notebook['notebook']['path'] notebook_path = cur_notebook['notebook']['path']
entry_point_filename = notebook_path.split(os.path.sep)[-1] # always slash, because this is from uri (so never backslash not even oon windows)
entry_point_filename = notebook_path.split('/')[-1]
# now we should try to find the actual file # now we should try to find the actual file
entry_point = (Path.cwd() / entry_point_filename).absolute() entry_point = (Path.cwd() / entry_point_filename).absolute()
@ -281,7 +299,7 @@ class ScriptInfo(object):
return '' return ''
@classmethod @classmethod
def _get_script_info(cls, filepath, check_uncommitted=True, log=None): def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
jupyter_filepath = cls._get_jupyter_notebook_filename() jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath: if jupyter_filepath:
script_path = Path(os.path.normpath(jupyter_filepath)).absolute() script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
@ -319,11 +337,15 @@ class ScriptInfo(object):
repo_root = repo_info.root or script_dir repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root) working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path) entry_point = cls._get_entry_point(repo_root, script_path)
diff = cls._get_script_code(script_path.as_posix()) if not plugin or not repo_info.commit else repo_info.diff if check_uncommitted:
diff = cls._get_script_code(script_path.as_posix()) \
if not plugin or not repo_info.commit else repo_info.diff
else:
diff = ''
# if this is not jupyter, get the requirements.txt # if this is not jupyter, get the requirements.txt
requirements = '' requirements = ''
# create requirements if backend supports requirements # create requirements if backend supports requirements
if not jupyter_filepath and Session.api_version > '2.1': if create_requirements and not jupyter_filepath and Session.api_version > '2.1':
script_requirements = ScriptRequirements(Path(repo_root).as_posix()) script_requirements = ScriptRequirements(Path(repo_root).as_posix())
requirements = script_requirements.get_requirements() requirements = script_requirements.get_requirements()
@ -351,11 +373,11 @@ class ScriptInfo(object):
return ScriptInfoResult(script=script_info, warning_messages=messages) return ScriptInfoResult(script=script_info, warning_messages=messages)
@classmethod @classmethod
def get(cls, filepath=sys.argv[0], check_uncommitted=True, log=None): def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
try: try:
return cls._get_script_info( return cls._get_script_info(
filepath=filepath, check_uncommitted=check_uncommitted, log=log filepath=filepath, check_uncommitted=check_uncommitted,
) create_requirements=create_requirements, log=log)
except Exception as ex: except Exception as ex:
if log: if log:
log.warning("Failed auto-detecting task repository: {}".format(ex)) log.warning("Failed auto-detecting task repository: {}".format(ex))

View File

@ -304,7 +304,7 @@ class Task(_Task):
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id): def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
if not default_project_name or not default_task_name: if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point # get project name and task name from repository name and entry_point
result = ScriptInfo.get() result = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
if result: if result:
if not default_project_name: if not default_project_name:
# noinspection PyBroadException # noinspection PyBroadException
@ -359,6 +359,8 @@ class Task(_Task):
else: else:
# reset the task, so we can update it # reset the task, so we can update it
task.reset(set_started_on_success=False, force=False) task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_tags(['development'])
# clear task parameters, they are not cleared by the Task reset # clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False) task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset # clear the comment, it is not cleared on reset
@ -884,15 +886,6 @@ class Task(_Task):
if not flush_period or flush_period > self._dev_worker.report_period: if not flush_period or flush_period > self._dev_worker.report_period:
logger.set_flush_period(self._dev_worker.report_period) logger.set_flush_period(self._dev_worker.report_period)
# Remove 'development' tag
tags = self.get_tags()
try:
tags.remove('development')
except ValueError:
pass
else:
self.set_tags(tags)
def _at_exit(self): def _at_exit(self):
""" """
Will happen automatically once we exit code, i.e. atexit Will happen automatically once we exit code, i.e. atexit