Integrate pigar into Trains

allegroai 2020-03-01 17:12:28 +02:00
parent 8ee2bd1844
commit 146da439e7
13 changed files with 1078 additions and 95 deletions

View File

@@ -11,7 +11,6 @@ jsonschema>=2.6.0
numpy>=1.10
pathlib2>=2.3.0
Pillow>=4.1.1
-pigar==0.9.2
plotly>=3.9.0
psutil>=3.4.2
pyparsing>=2.0.3

View File

@@ -26,103 +26,18 @@ class ScriptRequirements(object):
def __init__(self, root_folder):
self._root_folder = root_folder
-@staticmethod
-def get_installed_pkgs_detail(reqs):
-"""
-HACK: bugfix of the original pigar get_installed_pkgs_detail
-Get mapping for import top level name
-and install package name with version.
-"""
-mapping = dict()
-for path in sys.path:
-if os.path.isdir(path) and path.rstrip('/').endswith(
-('site-packages', 'dist-packages')):
-new_mapping = reqs._search_path(path)
-# BUGFIX:
-# override with previous, just like python resolves imports, the first match is the one used.
-# unlike the original implementation, where the last one is used.
-new_mapping.update(mapping)
-mapping = new_mapping
-# HACK: prefer tensorflow_gpu over tensorflow
-if 'tensorflow_gpu' in new_mapping:
-new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
-return mapping
def get_requirements(self):
try:
-from pigar import reqs
-reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
-from pigar.__main__ import GenerateReqs
-from pigar.log import logger
-logger.setLevel(logging.WARNING)
-try:
-# first try our version, if we fail revert to the internal implantation
-installed_pkgs = self.get_installed_pkgs_detail(reqs)
-except Exception:
-installed_pkgs = reqs.get_installed_pkgs_detail()
+from ....utilities.pigar.reqs import get_installed_pkgs_detail
+from ....utilities.pigar.__main__ import GenerateReqs
+installed_pkgs = get_installed_pkgs_detail()
gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
-reqs, try_imports, guess = gr.extract_reqs()
-return self.create_requirements_txt(reqs)
+reqs, try_imports, guess, local_pks = gr.extract_reqs(module_callback=ScriptRequirements.add_trains_used_packages)
+return self.create_requirements_txt(reqs, local_pks)
except Exception:
return '', ''
-@staticmethod
-def _patched_project_import_modules(project_path, ignores):
-"""
-copied form pigar req.project_import_modules
-patching, os.getcwd() is incorrectly used
-"""
-from pigar.modules import ImportedModules
-from pigar.reqs import file_import_modules
-modules = ImportedModules()
-try_imports = set()
-local_mods = list()
-ignore_paths = collections.defaultdict(set)
-if not ignores:
-ignore_paths[project_path].add('.git')
-else:
-for path in ignores:
-parent_dir = os.path.dirname(path)
-ignore_paths[parent_dir].add(os.path.basename(path))
-if os.path.isfile(project_path):
-fake_path = Path(project_path).name
-with open(project_path, 'rb') as f:
-fmodules, try_ipts = file_import_modules(fake_path, f.read())
-modules |= fmodules
-try_imports |= try_ipts
-else:
-cur_dir = project_path # os.getcwd()
-for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
-if dirpath in ignore_paths:
-dirnames[:] = [d for d in dirnames
-if d not in ignore_paths[dirpath]]
-py_files = list()
-for fn in files:
-# C extension.
-if fn.endswith('.so'):
-local_mods.append(fn[:-3])
-# Normal Python file.
-if fn.endswith('.py'):
-local_mods.append(fn[:-3])
-py_files.append(fn)
-if '__init__.py' in files:
-local_mods.append(os.path.basename(dirpath))
-for file in py_files:
-fpath = os.path.join(dirpath, file)
-fake_path = fpath.split(cur_dir)[1][1:]
-with open(fpath, 'rb') as f:
-fmodules, try_ipts = file_import_modules(fake_path, f.read())
-modules |= fmodules
-try_imports |= try_ipts
-return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
@staticmethod
def add_trains_used_packages(modules):
# hack: forcefully insert storage modules if we have them
@@ -159,7 +74,7 @@ class ScriptRequirements(object):
return modules
@staticmethod
-def create_requirements_txt(reqs):
+def create_requirements_txt(reqs, local_pks=None):
# write requirements.txt
try:
conda_requirements = ''
@@ -188,6 +103,11 @@ class ScriptRequirements(object):
# python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
+if local_pks:
+requirements_txt += '\n# Local modules found - skipping:\n'
+for k, v in local_pks.sorted_items():
+requirements_txt += '# {0} == {1}\n'.format(k, v.version)
# requirement summary
requirements_txt += '\n'
for k, v in reqs.sorted_items():
@@ -203,6 +123,13 @@ class ScriptRequirements(object):
requirements_txt += '\n' + \
'# Detailed import analysis\n' \
'# **************************\n'
+if local_pks:
+for k, v in local_pks.sorted_items():
+requirements_txt += '\n'
+requirements_txt += '# IMPORT LOCAL PACKAGE {0}\n'.format(k)
+requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
for k, v in reqs.sorted_items():
requirements_txt += '\n'
if k == '-e':
@@ -262,9 +189,9 @@ class _JupyterObserver(object):
# load pigar
# noinspection PyBroadException
try:
-from pigar.reqs import get_installed_pkgs_detail, file_import_modules
-from pigar.modules import ReqsModules
-from pigar.log import logger
+from ....utilities.pigar.reqs import get_installed_pkgs_detail, file_import_modules
+from ....utilities.pigar.modules import ReqsModules
+from ....utilities.pigar.log import logger
logger.setLevel(logging.WARNING)
except Exception:
file_import_modules = None
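
For orientation, a minimal usage sketch of the class changed above. The module path is inferred from the relative `....utilities.pigar` imports, and the tuple return shape from the `except` branch returning '', ''; both are assumptions, not confirmed by the diff:

    # Hypothetical driver for ScriptRequirements with the vendored pigar.
    from trains.backend_interface.task.repo.scriptinfo import ScriptRequirements

    script_requirements = ScriptRequirements('/path/to/project')
    # Walks the project, maps imports to installed packages via the vendored
    # pigar modules, and renders a requirements.txt-style summary string.
    requirements_txt, conda_requirements = script_requirements.get_requirements()
    print(requirements_txt)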

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import os
import codecs
from .reqs import project_import_modules, is_std_or_local_lib
from .utils import lines_diff
from .log import logger
from .modules import ReqsModules
class GenerateReqs(object):
def __init__(self, save_path, project_path, ignores,
installed_pkgs, comparison_operator='=='):
self._save_path = save_path
self._project_path = project_path
self._ignores = ignores
self._installed_pkgs = installed_pkgs
self._maybe_local_mods = set()
self._local_mods = dict()
self._comparison_operator = comparison_operator
def extract_reqs(self, module_callback=None):
"""Extract requirements from project."""
reqs = ReqsModules()
guess = ReqsModules()
local = ReqsModules()
modules, try_imports, local_mods = project_import_modules(
self._project_path, self._ignores)
if module_callback:
modules = module_callback(modules)
app_name = os.path.basename(self._project_path)
if app_name in local_mods:
local_mods.remove(app_name)
# Filtering modules
candidates = self._filter_modules(modules, local_mods)
logger.info('Check module in local environment.')
for name in candidates:
logger.info('Checking module: %s', name)
if name in self._installed_pkgs:
pkg_name, version = self._installed_pkgs[name]
reqs.add(pkg_name, version, modules[name])
else:
guess.add(name, 0, modules[name])
# add local modules, so we know what is used but not installed.
for name in self._local_mods:
if name in modules:
relpath = os.path.relpath(self._local_mods[name], self._project_path)
if not relpath.startswith('.'):
relpath = '.' + os.path.sep + relpath
local.add(name, relpath, modules[name])
return reqs, try_imports, guess, local
def _write_reqs(self, reqs):
print('Writing requirements to "{0}"'.format(
self._save_path))
with open(self._save_path, 'w+') as f:
f.write('# Requirements automatically generated by pigar.\n'
'# https://github.com/damnever/pigar\n')
for k, v in reqs.sorted_items():
f.write('\n')
f.write(''.join(['# {0}\n'.format(c)
for c in v.comments.sorted_items()]))
if k == '-e':
f.write('{0} {1}\n'.format(k, v.version))
elif v:
f.write('{0} {1} {2}\n'.format(
k, self._comparison_operator, v.version))
else:
f.write('{0}\n'.format(k))
def _best_matchs(self, name, pkgs):
# If the imported name equals the package name.
if name in pkgs:
return [pkgs[pkgs.index(name)]]
# If not, return all possible packages.
return pkgs
def _filter_modules(self, modules, local_mods):
candidates = set()
logger.info('Filtering modules ...')
for module in modules:
logger.info('Checking module: %s', module)
if not module or module.startswith('.'):
continue
if module in local_mods:
self._maybe_local_mods.add(module)
module_std_local = is_std_or_local_lib(module)
if module_std_local is True:
continue
if isinstance(module_std_local, str):
self._local_mods[module] = module_std_local
continue
candidates.add(module)
return candidates
def _invalid_reqs(self, reqs):
for name, detail in reqs.sorted_items():
print(
' {0} referenced from:\n {1}'.format(
name,
'\n '.join(detail.comments.sorted_items())
)
)
def _save_old_reqs(self):
if os.path.isfile(self._save_path):
with codecs.open(self._save_path, 'rb', 'utf-8') as f:
self._old_reqs = f.readlines()
def _reqs_diff(self):
if not hasattr(self, '_old_reqs'):
return
with codecs.open(self._save_path, 'rb', 'utf-8') as f:
new_reqs = f.readlines()
is_diff, diffs = lines_diff(self._old_reqs, new_reqs)
msg = 'Requirements file has been overwritten, '
if is_diff:
msg += 'here is the difference:'
print('{0}\n{1}'.format(msg, ''.join(diffs)), end='')
else:
msg += 'no difference.'
print(msg)
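
A short sketch of driving GenerateReqs directly, mirroring the call site in ScriptRequirements.get_requirements above; the absolute `trains.utilities.pigar` import path is assumed from this commit's layout:

    from trains.utilities.pigar.reqs import get_installed_pkgs_detail
    from trains.utilities.pigar.__main__ import GenerateReqs

    installed_pkgs = get_installed_pkgs_detail()
    gr = GenerateReqs(save_path='', project_path='/path/to/project',
                      installed_pkgs=installed_pkgs,
                      ignores=['.git', '__pycache__'])
    # `local` collects modules that are imported but only exist inside the project.
    reqs, try_imports, guess, local = gr.extract_reqs()
    for name, detail in reqs.sorted_items():
        print(name, detail.version)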

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from ..utils import PY32
if PY32:
from .thread_extractor import ThreadExtractor as Extractor
else:
from .gevent_extractor import GeventExtractor as Extractor

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import multiprocessing
class BaseExtractor(object):
def __init__(self, names, max_workers=None):
self._names = names
self._max_workers = max_workers or (multiprocessing.cpu_count() * 4)
def run(self, job):
try:
self.extract(job)
self.wait_complete()
except KeyboardInterrupt:
print('** Shutting down ...')
self.shutdown()
else:
print('^.^ Extracting all packages done!')
finally:
self.final()
def extract(self, job):
raise NotImplementedError
def wait_complete(self):
raise NotImplementedError
def shutdown(self):
raise NotImplementedError
def final(self):
pass
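
BaseExtractor is a template method: run() drives extract() / wait_complete() / shutdown(), which subclasses must supply. A toy serial subclass (not part of the commit; the import path is assumed from the vendored layout) illustrating the contract:

    from trains.utilities.pigar.extractors.extractor import BaseExtractor

    class SerialExtractor(BaseExtractor):
        def extract(self, job):
            # run each job inline instead of scheduling it on workers
            for name in self._names:
                job(name)

        def wait_complete(self):
            pass  # nothing pending: extract() already ran every job

        def shutdown(self):
            pass  # no workers to cancel

    SerialExtractor(['requests', 'flask']).run(print)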

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import sys
import greenlet
from gevent.pool import Pool
from .extractor import BaseExtractor
from ..log import logger
class GeventExtractor(BaseExtractor):
def __init__(self, names, max_workers=222):
super(self.__class__, self).__init__(names, max_workers)
self._pool = Pool(self._max_workers)
self._exited_greenlets = 0
def extract(self, job):
job = self._job_wrapper(job)
for name in self._names:
if self._pool.full():
self._pool.wait_available()
self._pool.spawn(job, name)
def _job_wrapper(self, job):
def _job(name):
result = None
try:
result = job(name)
except greenlet.GreenletExit:
self._exited_greenlets += 1
except Exception:
e = sys.exc_info()[1]
logger.error('Extracting "{0}", got: {1}'.format(name, e))
return result
return _job
def wait_complete(self):
self._pool.join()
def shutdown(self):
self._pool.kill(block=True)
def final(self):
count = self._exited_greenlets
if count != 0:
print('** {0} running job(s) exited.'.format(count))

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import concurrent.futures
from .extractor import BaseExtractor
from ..log import logger
class ThreadExtractor(BaseExtractor):
"""Extractor use thread pool execute tasks.
Can be used to extract /simple/<pkg_name> or /pypi/<pkg_name>/json.
FIXME: can not deliver SIG_INT to threads in Python 2.
"""
def __init__(self, names, max_workers=None):
super(self.__class__, self).__init__(names, max_workers)
self._futures = dict()
def extract(self, job):
"""Extract url by package name."""
with concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_workers) as executor:
for name in self._names:
self._futures[executor.submit(job, name)] = name
def wait_complete(self):
"""Wait for futures complete done."""
for future in concurrent.futures.as_completed(self._futures.keys()):
try:
error = future.exception()
except concurrent.futures.CancelledError:
break
name = self._futures[future]
if error is not None:
err_msg = 'Extracting "{0}", got: {1}'.format(name, error)
logger.error(err_msg)
def shutdown(self):
for future in self._futures:
future.cancel()
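
A hedged usage sketch for the thread-pool variant: the job below is a stand-in for pigar's real metadata fetchers, and failures surface via logger.error in wait_complete(). The import path assumes the vendored layout of this commit:

    from trains.utilities.pigar.extractors.thread_extractor import ThreadExtractor

    def fetch(name):
        # placeholder job; in pigar this would hit /simple/<name> or /pypi/<name>/json
        return len(name)

    ThreadExtractor(['requests', 'flask'], max_workers=4).run(fetch)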

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import logging.handlers
logger = logging.getLogger('pigar')
logger.setLevel(logging.WARNING)

View File

@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import collections
# FIXME: Just a workaround, not a radical cure..
_special_cases = {
"dogpile.cache": "dogpile.cache",
"dogpile.core": "dogpile.core",
"ruamel.yaml": "ruamel.yaml",
"ruamel.ordereddict": "ruamel.ordereddict",
}
class Modules(dict):
"""Modules object will be used to store modules information."""
def __init__(self):
super(Modules, self).__init__()
class ImportedModules(Modules):
def __init__(self):
super(ImportedModules, self).__init__()
def add(self, name, file, lineno):
if name is None:
return
names = list()
special_name = '.'.join(name.split('.')[:2])
# Flask extension.
if name.startswith('flask.ext.'):
names.append('flask')
names.append('flask_' + name.split('.')[2])
# Special cases..
elif special_name in _special_cases:
names.append(_special_cases[special_name])
# Other.
elif '.' in name and not name.startswith('.'):
names.append(name.split('.')[0])
else:
names.append(name)
for nm in names:
if nm not in self:
self[nm] = _Locations()
self[nm].add(file, lineno)
def __or__(self, obj):
for name, locations in obj.items():
for file, linenos in locations.items():
for lineno in linenos:
self.add(name, file, lineno)
return self
class ReqsModules(Modules):
_Detail = collections.namedtuple('Detail', ['version', 'comments'])
def __init__(self):
super(ReqsModules, self).__init__()
self._sorted = None
def add(self, package, version, locations):
if package in self:
self[package].comments.extend(locations)
else:
self[package] = self._Detail(version, locations)
def sorted_items(self):
if self._sorted is None:
self._sorted = sorted(self.items())
return self._sorted
def remove(self, *names):
for name in names:
if name in self:
self.pop(name)
self._sorted = None
class _Locations(dict):
"""_Locations store code locations(file, linenos)."""
def __init__(self):
super(_Locations, self).__init__()
self._sorted = None
def add(self, file, lineno):
if file in self and lineno not in self[file]:
self[file].append(lineno)
else:
self[file] = [lineno]
def extend(self, obj):
for file, linenos in obj.items():
for lineno in linenos:
self.add(file, lineno)
def sorted_items(self):
if self._sorted is None:
self._sorted = [
'{0}: {1}'.format(f, ','.join([str(n) for n in sorted(ls)]))
for f, ls in sorted(self.items())
]
return self._sorted
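
A small sketch of how these containers fit together: ImportedModules maps a top-level import name to the (file, lineno) locations where it appears, and ReqsModules pairs each package with a version plus those locations. Import path and all values are illustrative:

    from trains.utilities.pigar.modules import ImportedModules, ReqsModules

    imported = ImportedModules()
    imported.add('numpy.linalg', 'train.py', 3)   # stored under top-level 'numpy'
    imported.add('flask.ext.login', 'app.py', 1)  # stored as 'flask' and 'flask_login'

    reqs = ReqsModules()
    reqs.add('numpy', '1.18.1', imported['numpy'])
    for pkg, detail in reqs.sorted_items():
        print(pkg, detail.version, detail.comments.sorted_items())
        # e.g. numpy 1.18.1 ['train.py: 3']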

View File

@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import os
import sys
import fnmatch
import importlib
import imp
import ast
import doctest
import collections
import functools
from pathlib2 import Path
try:
from types import FileType # py2
except ImportError:
from io import IOBase as FileType # py3
from .log import logger
from .utils import parse_git_config
from .modules import ImportedModules
def project_import_modules(project_path, ignores):
"""
Copied from pigar reqs.project_import_modules; patched because os.getcwd() was used incorrectly instead of the project path.
"""
modules = ImportedModules()
try_imports = set()
local_mods = list()
ignore_paths = collections.defaultdict(set)
if not ignores:
ignore_paths[project_path].add('.git')
else:
for path in ignores:
parent_dir = os.path.dirname(path)
ignore_paths[parent_dir].add(os.path.basename(path))
if os.path.isfile(project_path):
fake_path = Path(project_path).name
with open(project_path, 'rb') as f:
fmodules, try_ipts = file_import_modules(fake_path, f.read())
modules |= fmodules
try_imports |= try_ipts
else:
cur_dir = project_path # os.getcwd()
for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
if dirpath in ignore_paths:
dirnames[:] = [d for d in dirnames
if d not in ignore_paths[dirpath]]
py_files = list()
for fn in files:
# C extension.
if fn.endswith('.so'):
local_mods.append(fn[:-3])
# Normal Python file.
if fn.endswith('.py'):
local_mods.append(fn[:-3])
py_files.append(fn)
if '__init__.py' in files:
local_mods.append(os.path.basename(dirpath))
for file in py_files:
fpath = os.path.join(dirpath, file)
fake_path = fpath.split(cur_dir)[1][1:]
with open(fpath, 'rb') as f:
fmodules, try_ipts = file_import_modules(fake_path, f.read())
modules |= fmodules
try_imports |= try_ipts
return modules, try_imports, local_mods
def file_import_modules(fpath, fdata):
"""Get single file all imported modules."""
modules = ImportedModules()
str_codes = collections.deque([(fdata, 1)])
try_imports = set()
while str_codes:
str_code, lineno = str_codes.popleft()
ic = ImportChecker(fpath, lineno)
try:
parsed = ast.parse(str_code)
ic.visit(parsed)
# Ignore SyntaxError in Python code.
except SyntaxError:
pass
modules |= ic.modules
str_codes.extend(ic.str_codes)
try_imports |= ic.try_imports
del ic
return modules, try_imports
class ImportChecker(object):
def __init__(self, fpath, lineno):
self._fpath = fpath
self._lineno = lineno - 1
self._modules = ImportedModules()
self._str_codes = collections.deque()
self._try_imports = set()
def visit_Import(self, node, try_=False):
"""As we know: `import a [as b]`."""
lineno = node.lineno + self._lineno
for alias in node.names:
self._modules.add(alias.name, self._fpath, lineno)
if try_:
self._try_imports.add(alias.name)
def visit_ImportFrom(self, node, try_=False):
"""
As we know: `from a import b [as c]`. If node.level is not 0,
the import is relative, e.g. `from .a import b`.
"""
mod_name = node.module
level = node.level
if mod_name is None:
level -= 1
mod_name = ''
for alias in node.names:
name = level*'.' + mod_name + '.' + alias.name
self._modules.add(name, self._fpath, node.lineno + self._lineno)
if try_:
self._try_imports.add(name)
def visit_TryExcept(self, node):
"""
If modules imported inside `try ... except` blocks are not found,
they may come from a different Python version.
"""
for ipt in node.body:
if ipt.__class__.__name__.startswith('Import'):
method = 'visit_' + ipt.__class__.__name__
getattr(self, method)(ipt, True)
for handler in node.handlers:
for ipt in handler.body:
if ipt.__class__.__name__.startswith('Import'):
method = 'visit_' + ipt.__class__.__name__
getattr(self, method)(ipt, True)
# For Python 3.3+
visit_Try = visit_TryExcept
def visit_Exec(self, node):
"""
Check `expression` of `exec(expression[, globals[, locals]])`.
**Just available in python 2.**
"""
if hasattr(node.body, 's'):
self._str_codes.append((node.body.s, node.lineno + self._lineno))
# PR#13: https://github.com/damnever/pigar/pull/13
# Sometimes exec statement may be called with tuple in Py2.7.6
elif hasattr(node.body, 'elts') and len(node.body.elts) >= 1:
self._str_codes.append(
(node.body.elts[0].s, node.lineno + self._lineno))
def visit_Expr(self, node):
"""
Check `expression` of `eval(expression[, globals[, locals]])`.
Check `expression` of `exec(expression[, globals[, locals]])`
in python 3.
Check `name` of `__import__(name[, globals[, locals[,
fromlist[, level]]]])`.
Check `name` or `package` of `importlib.import_module(name,
package=None)`.
"""
# Built-in functions
value = node.value
if isinstance(value, ast.Call):
if hasattr(value.func, 'id'):
if (value.func.id == 'eval' and
hasattr(node.value.args[0], 's')):
self._str_codes.append(
(node.value.args[0].s, node.lineno + self._lineno))
# **`exec` function in Python 3.**
elif (value.func.id == 'exec' and
hasattr(node.value.args[0], 's')):
self._str_codes.append(
(node.value.args[0].s, node.lineno + self._lineno))
# `__import__` function.
elif (value.func.id == '__import__' and
len(node.value.args) > 0 and
hasattr(node.value.args[0], 's')):
self._modules.add(node.value.args[0].s, self._fpath,
node.lineno + self._lineno)
# `import_module` function.
elif getattr(value.func, 'attr', '') == 'import_module':
module = getattr(value.func, 'value', None)
if (module is not None and
getattr(module, 'id', '') == 'importlib'):
args = node.value.args
arg_len = len(args)
if arg_len > 0 and hasattr(args[0], 's'):
name = args[0].s
if not name.startswith('.'):
self._modules.add(name, self._fpath,
node.lineno + self._lineno)
elif arg_len == 2 and hasattr(args[1], 's'):
self._modules.add(args[1].s, self._fpath,
node.lineno + self._lineno)
def visit_FunctionDef(self, node):
"""
Check the function docstring, in case it contains doctest code.
"""
docstring = self._parse_docstring(node)
if docstring:
self._str_codes.append((docstring, node.lineno + self._lineno + 2))
def visit_ClassDef(self, node):
"""
Check the class docstring, in case it contains doctest code.
"""
docstring = self._parse_docstring(node)
if docstring:
self._str_codes.append((docstring, node.lineno + self._lineno + 2))
def visit(self, node):
"""Visit a node, no recursively."""
for node in ast.walk(node):
method = 'visit_' + node.__class__.__name__
getattr(self, method, lambda x: x)(node)
@staticmethod
def _parse_docstring(node):
"""Extract code from docstring."""
docstring = ast.get_docstring(node)
if docstring:
parser = doctest.DocTestParser()
try:
dt = parser.get_doctest(docstring, {}, None, None, None)
except ValueError:
# >>> 'abc'
pass
else:
examples = dt.examples
return '\n'.join([example.source for example in examples])
return None
@property
def modules(self):
return self._modules
@property
def str_codes(self):
return self._str_codes
@property
def try_imports(self):
return set((name.split('.')[0] if name and '.' in name else name)
for name in self._try_imports)
def _checked_cache(func):
checked = dict()
@functools.wraps(func)
def _wrapper(name):
if name not in checked:
checked[name] = func(name)
return checked[name]
return _wrapper
@_checked_cache
def is_std_or_local_lib(name):
"""Check whether it is stdlib module.
True if std lib
False if installed package
str if local library
"""
exist = True
module_info = ('', '', '')
try:
module_info = imp.find_module(name)
except ImportError:
try:
# __import__(name)
importlib.import_module(name)
module_info = imp.find_module(name)
sys.modules.pop(name)
except ImportError:
exist = False
# Testcase: ResourceWarning
if isinstance(module_info[0], FileType):
module_info[0].close()
mpath = module_info[1]
if exist and mpath is not None:
if ('site-packages' in mpath or
'dist-packages' in mpath or
'bin/' in mpath and mpath.endswith('.py')):
exist = False
elif ((sys.prefix not in mpath) and
(sys.base_exec_prefix not in mpath) and
(sys.base_prefix not in mpath)):
exist = mpath
return exist
def get_installed_pkgs_detail():
"""
HACK: bugfix of the original pigar get_installed_pkgs_detail.
Get a mapping from import top-level names to installed
package names with their versions.
"""
mapping = dict()
for path in sys.path:
if os.path.isdir(path) and path.rstrip('/').endswith(
('site-packages', 'dist-packages')):
new_mapping = _search_path(path)
# BUGFIX:
# override with previous: just like Python resolves imports, the first match wins,
# unlike the original implementation, where the last one was used.
new_mapping.update(mapping)
mapping = new_mapping
# HACK: prefer tensorflow_gpu over tensorflow
if 'tensorflow_gpu' in new_mapping:
new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
return mapping
def _search_path(path):
mapping = dict()
for file in os.listdir(path):
# Install from PYPI.
if fnmatch.fnmatch(file, '*-info'):
top_level = os.path.join(path, file, 'top_level.txt')
pkg_name, version = file.split('-')[:2]
if version.endswith('dist'):
version = version.rsplit('.', 1)[0]
# Issue for ubuntu: sudo pip install xxx
elif version.endswith('egg'):
version = version.rsplit('.', 1)[0]
mapping[pkg_name] = (pkg_name, version)
if not os.path.isfile(top_level):
continue
with open(top_level, 'r') as f:
for line in f:
mapping[line.strip()] = (pkg_name, version)
# Install from local and available in GitHub.
elif fnmatch.fnmatch(file, '*-link'):
link = os.path.join(path, file)
if not os.path.isfile(link):
continue
# Link path.
with open(link, 'r') as f:
for line in f:
line = line.strip()
if line != '.':
dev_dir = line
if not dev_dir:
continue
# Egg info path.
info_dir = [_file for _file in os.listdir(dev_dir)
if _file.endswith('egg-info')]
if not info_dir:
continue
info_dir = info_dir[0]
top_level = os.path.join(dev_dir, info_dir, 'top_level.txt')
# Check whether it can be imported.
if not os.path.isfile(top_level):
continue
# Check .git dir.
git_path = os.path.join(dev_dir, '.git')
if os.path.isdir(git_path):
config = parse_git_config(git_path)
url = config.get('remote "origin"', {}).get('url')
if not url:
continue
branch = 'branch "master"'
if branch not in config:
for section in config:
if 'branch' in section:
branch = section
break
if not branch:
continue
branch = branch.split()[1][1:-1]
pkg_name = info_dir.split('.egg')[0]
git_url = 'git+{0}@{1}#egg={2}'.format(url, branch, pkg_name)
with open(top_level, 'r') as f:
for line in f:
mapping[line.strip()] = ('-e', git_url)
return mapping
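
The shape of the mapping returned by get_installed_pkgs_detail is what lets an import name resolve to a differently named pip package; a sketch with illustrative values:

    mapping = get_installed_pkgs_detail()
    # mapping['cv2']  -> ('opencv_python', '4.2.0')   # import name != package name
    # mapping['yaml'] -> ('PyYAML', '5.3')
    # editable (pip install -e) packages map to ('-e', 'git+<url>@<branch>#egg=<name>')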

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import tarfile
import zipfile
import re
import string
import io
class Archive(object):
"""Archive provides a consistent interface for unpacking
compressed files.
"""
def __init__(self, filename, fileobj):
self._filename = filename
self._fileobj = fileobj
self._file = None
self._names = None
self._read = None
@property
def filename(self):
return self._filename
@property
def names(self):
"""If name list is not required, do not get it."""
if self._file is None:
self._prepare()
if not hasattr(self, '_namelist'):
self._namelist = self._names()
return self._namelist
def close(self):
"""Close file object."""
if self._file is not None:
self._file.close()
if hasattr(self, '_namelist'):
del self._namelist
self._filename = self._fileobj = None
self._file = self._names = self._read = None
def read(self, filename):
"""Read one file from archive."""
if self._file is None:
self._prepare()
return self._read(filename)
def unpack(self, to_path):
"""Unpack compressed files to path."""
if self._file is None:
self._prepare()
self._safe_extractall(to_path)
def _prepare(self):
if self._filename.endswith(('.tar.gz', '.tar.bz2', '.tar.xz')):
self._prepare_tarball()
# An .egg file is actually just a .zip file
# with a different extension, .whl too.
elif self._filename.endswith(('.zip', '.egg', '.whl')):
self._prepare_zip()
else:
raise ValueError("unreadable: {0}".format(self._filename))
def _safe_extractall(self, to_path='.'):
unsafe = []
for name in self.names:
if not self.is_safe(name):
unsafe.append(name)
if unsafe:
raise ValueError("unsafe to unpack: {}".format(unsafe))
self._file.extractall(to_path)
def _prepare_zip(self):
self._file = zipfile.ZipFile(self._fileobj)
self._names = self._file.namelist
self._read = self._file.read
def _prepare_tarball(self):
# tarfile has no read method
def _read(filename):
f = self._file.extractfile(filename)
return f.read()
self._file = tarfile.open(mode='r:*', fileobj=self._fileobj)
self._names = self._file.getnames
self._read = _read
def is_safe(self, filename):
return not (filename.startswith(("/", "\\")) or
(len(filename) > 1 and filename[1] == ":" and
filename[0] in string.ascii_letters) or
re.search(r"[.][.][/\\]", filename))
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def top_level(url, data):
"""Read top level names from compressed file."""
sb = io.BytesIO(data)
txt = None
with Archive(url, sb) as archive:
file = None
for name in archive.names:
if name.lower().endswith('top_level.txt'):
file = name
break
if file:
txt = archive.read(file).decode('utf-8')
sb.close()
return [name.replace('/', '.') for name in txt.splitlines()] if txt else []
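
A sketch of reading top_level.txt straight out of a downloaded wheel without unpacking it to disk; filename and result are illustrative:

    wheel = 'Pillow-4.1.1-cp36-cp36m-manylinux1_x86_64.whl'
    with open(wheel, 'rb') as f:
        names = top_level(wheel, f.read())
    # e.g. ['PIL']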

View File

@@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import os
import re
import sys
import difflib
PY32 = sys.version_info[:2] == (3, 2)
if sys.version_info[0] == 3:
binary_type = bytes
else:
binary_type = str
class Dict(dict):
"""Convert dict key object to attribute."""
def __init__(self, *args, **kwargs):
super(Dict, self).__init__(*args, **kwargs)
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError('"{0}"'.format(name))
def __setattr__(self, name, value):
self[name] = value
def parse_reqs(fpath):
"""Parse requirements file."""
pkg_v_re = re.compile(r'^(?P<pkg>[^><==]+)[><==]{,2}(?P<version>.*)$')
reqs = dict()
with open(fpath, 'r') as f:
for line in f:
if line.startswith('#'):
continue
m = pkg_v_re.match(line.strip())
if m:
d = m.groupdict()
reqs[d['pkg'].strip()] = d['version'].strip()
return reqs
def cmp_to_key(cmp_func):
"""Convert a cmp=fcuntion into a key=function."""
class K(object):
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return cmp_func(self.obj, other.obj) < 0
def __gt__(self, other):
return cmp_func(self.obj, other.obj) > 0
def __eq__(self, other):
return cmp_func(self.obj, other.obj) == 0
return K
def compare_version(version1, version2):
"""Compare version number, such as 1.1.1 and 1.1b2.0."""
v1, v2 = list(), list()
for item in version1.split('.'):
if item.isdigit():
v1.append(int(item))
else:
v1.extend([i for i in _group_alnum(item)])
for item in version2.split('.'):
if item.isdigit():
v2.append(int(item))
else:
v2.extend([i for i in _group_alnum(item)])
while v1 and v2:
item1, item2 = v1.pop(0), v2.pop(0)
if item1 > item2:
return 1
elif item1 < item2:
return -1
if v1:
return 1
elif v2:
return -1
return 0
def _group_alnum(s):
tmp = list()
flag = 1 if s[0].isdigit() else 0
for c in s:
if c.isdigit():
if flag == 0:
yield ''.join(tmp)
tmp = list()
flag = 1
tmp.append(c)
elif c.isalpha():
if flag == 1:
yield int(''.join(tmp))
tmp = list()
flag = 0
tmp.append(c)
last = ''.join(tmp)
yield (int(last) if flag else last)
def parse_git_config(path):
"""Parse git config file."""
config = dict()
section = None
with open(os.path.join(path, 'config'), 'r') as f:
for line in f:
line = line.strip()
if line.startswith('['):
section = line[1: -1].strip()
config[section] = dict()
elif section:
key, value = line.replace(' ', '').split('=')
config[section][key] = value
return config
def lines_diff(lines1, lines2):
"""Show difference between lines."""
is_diff = False
diffs = list()
for line in difflib.ndiff(lines1, lines2):
if not is_diff and line[0] in ('+', '-'):
is_diff = True
diffs.append(line)
return is_diff, diffs
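
Worked examples for compare_version above: each dot-separated segment is compared numerically, with mixed segments like '1b2' split into digit/alpha runs by _group_alnum:

    assert compare_version('1.10.0', '1.2.0') == 1   # 10 > 2, unlike string ordering
    assert compare_version('1.1b2', '1.1b10') == -1  # splits to [1, 1, 'b', 2] vs [1, 1, 'b', 10]
    assert compare_version('2.0', '2.0.1') == -1     # the shorter version is smaller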