Initial release

This commit is contained in:
allegroai
2019-10-25 22:28:44 +03:00
parent 4d808bedc5
commit 224a709f40
101 changed files with 26263 additions and 0 deletions

0
tests/__init__.py Normal file
View File

51
tests/conftest.py Normal file
View File

@@ -0,0 +1,51 @@
from argparse import Namespace
from contextlib import contextmanager
import pytest
import yaml
from pathlib2 import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
@pytest.fixture(scope='function')
def run_trains_agent(script_runner):
""" Execute trains_agent agent app in subprocess and return stdout as a string.
Args:
script_runner (object): a pytest plugin for testing python scripts
installed via console_scripts entry point of setup.py.
It can run the scripts under test in a separate process or using the interpreter that's running
the test suite. The former mode ensures that the script will run in an environment that
is identical to normal execution whereas the latter one allows much quicker test runs during development
while simulating the real runs as muh as possible.
For more details: https://pypi.python.org/pypi/pytest-console-scripts
Returns:
string: The return value. stdout output
"""
def _method(*args):
trains_agent_file = str(PROJECT_ROOT / "trains_agent.sh")
ret = script_runner.run(trains_agent_file, *args)
return ret
return _method
@pytest.fixture(scope='function')
def trains_agentyaml(tmpdir):
@contextmanager
def _method(template_file):
file = tmpdir.join("trains_agent.yaml")
with (PROJECT_ROOT / "tests/templates" / template_file).open() as f:
code = yaml.load(f)
yield Namespace(code=code, file=file.strpath)
file.write(yaml.dump(code))
return _method
# class Test(object):
# def yaml_file(self, tmpdir, template_file):
# file = tmpdir.join("trains_agent.yaml")
# with open(template_file) as f:
# test_object = yaml.load(f)
# self.let(test_object)
# file.write(yaml.dump(test_object))
# return file.strpath

View File

View File

@@ -0,0 +1,35 @@
import pytest
from trains_agent.helper.repo import VCS
@pytest.mark.parametrize(
["url", "expected"],
(
("a", None),
("foo://a/b", None),
("foo://a/b/", None),
("https://a/b/", None),
("https://example.com/a/b", None),
("https://example.com/a/b/", None),
("ftp://example.com/a/b", None),
("ftp://example.com/a/b/", None),
("github.com:foo/bar.git", "https://github.com/foo/bar.git"),
("git@github.com:foo/bar.git", "https://github.com/foo/bar.git"),
("bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
("hg@bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
("ssh://bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
("ssh://git@github.com/foo/bar.git", "https://github.com/foo/bar.git"),
("ssh://user@github.com/foo/bar.git", "https://user@github.com/foo/bar.git"),
("ssh://git:password@github.com/foo/bar.git", "https://git:password@github.com/foo/bar.git"),
("ssh://user:password@github.com/foo/bar.git", "https://user:password@github.com/foo/bar.git"),
("ssh://hg@bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
("ssh://user@bitbucket.org/foo/bar.git", "https://user@bitbucket.org/foo/bar.git"),
("ssh://hg:password@bitbucket.org/foo/bar.git", "https://hg:password@bitbucket.org/foo/bar.git"),
("ssh://user:password@bitbucket.org/foo/bar.git", "https://user:password@bitbucket.org/foo/bar.git"),
),
)
def test(url, expected):
result = VCS.resolve_ssh_url(url)
expected = expected or url
assert result == expected

View File

@@ -0,0 +1,76 @@
import pytest
from furl import furl
from trains_agent.helper.package.translator import RequirementsTranslator
@pytest.mark.parametrize(
"line",
(
furl()
.set(
scheme=scheme,
host=host,
path=path,
query=query,
fragment=fragment,
port=port,
username=username,
password=password,
)
.url
for scheme in ("http", "https", "ftp")
for host in ("a", "example.com")
for path in (None, "/", "a", "a/", "a/b", "a/b/", "a b", "a b ")
for query in (None, "foo", "foo=3", "foo=3&bar")
for fragment in (None, "foo")
for port in (None, 1337)
for username in (None, "", "user")
for password in (None, "", "password")
),
)
def test_supported(line):
assert "://" in line
assert RequirementsTranslator.is_supported_link(line)
@pytest.mark.parametrize(
"line",
[
"pytorch",
"foo",
"foo1",
"bar",
"bar1",
"foo-bar",
"foo-bar1",
"foo-bar-1",
"foo_bar",
"foo_bar1",
"foo_bar_1",
" https://a",
" https://a/b",
" http://a",
" http://a/b",
" ftp://a/b",
"file://a/b",
"ssh://a/b",
"foo://a/b",
"git//a/b",
"git+https://a/b",
"https+git://a/b",
"git+http://a/b",
"http+git://a/b",
"",
" ",
"-e ",
"-e x",
"-e http://a",
"-e http://a/b",
"-e https://a",
"-e https://a/b",
"-e file://a/b",
],
)
def test_not_supported(line):
assert not RequirementsTranslator.is_supported_link(line)

View File

@@ -0,0 +1,33 @@
import attr
import pytest
import requests
from furl import furl
import six
from trains_agent.helper.package.pytorch import PytorchRequirement
@attr.s
class PytorchURLWheel(object):
os = attr.ib()
cuda = attr.ib()
python = attr.ib()
pytorch = attr.ib()
url = attr.ib()
wheels = [
PytorchURLWheel(os=os, cuda=cuda, python=python, pytorch=pytorch_version, url=url)
for os, os_d in PytorchRequirement.MAP.items()
for cuda, cuda_d in os_d.items()
if isinstance(cuda_d, dict)
for python, python_d in cuda_d.items()
if isinstance(python_d, dict)
for pytorch_version, url in python_d.items()
if isinstance(url, six.string_types) and furl(url).scheme
]
@pytest.mark.parametrize('wheel', wheels, ids=[','.join(map(str, attr.astuple(wheel))) for wheel in wheels])
def test_map(wheel):
assert requests.head(wheel.url).ok

View File

@@ -0,0 +1,75 @@
import pytest
from trains_agent.helper.repo import Git
NO_CHANGE = object()
def param(url, expected, user=False, password=False):
"""
Helper function for creating parametrization arguments.
:param url: input url
:param expected: expected output URL or NO_CHANGE if the same as input URL
:param user: Add `agent.git_user=USER` to config
:param password: Add `agent.git_password=PASSWORD` to config
"""
expected_repr = "NO_CHANGE" if expected is NO_CHANGE else None
user = "USER" if user else None
password = "PASSWORD" if password else None
return pytest.param(
url,
expected,
user,
password,
id="-".join(filter(None, (url, user, password, expected_repr))),
)
@pytest.mark.parametrize(
"url,expected,user,password",
[
param("https://bitbucket.org/company/repo", NO_CHANGE),
param("https://bitbucket.org/company/repo", NO_CHANGE, user=True),
param("https://bitbucket.org/company/repo", NO_CHANGE, password=True),
param(
"https://bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
),
param("https://user@bitbucket.org/company/repo", NO_CHANGE),
param("https://user@bitbucket.org/company/repo", NO_CHANGE, user=True),
param("https://user@bitbucket.org/company/repo", NO_CHANGE, password=True),
param(
"https://user@bitbucket.org/company/repo",
"https://USER:PASSWORD@bitbucket.org/company/repo",
user=True,
password=True,
),
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE),
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE, user=True),
param(
"https://user:password@bitbucket.org/company/repo", NO_CHANGE, password=True
),
param(
"https://user:password@bitbucket.org/company/repo",
NO_CHANGE,
user=True,
password=True,
),
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE),
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True),
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, password=True),
param(
"ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
),
param("git@bitbucket.org:company/repo.git", NO_CHANGE),
param("git@bitbucket.org:company/repo.git", NO_CHANGE, user=True),
param("git@bitbucket.org:company/repo.git", NO_CHANGE, password=True),
param(
"git@bitbucket.org:company/repo.git", NO_CHANGE, user=True, password=True
),
],
)
def test(url, user, password, expected):
config = {"agent": {"git_user": user, "git_pass": password}}
result = Git.add_auth(config, url)
expected = result if expected is NO_CHANGE else expected
assert result == expected

View File

@@ -0,0 +1,254 @@
"""
Test handling of jupyter notebook tasks.
Logging is enabled in `trains_agent/tests/pytest.ini`. Search for `pytest live logging` for more info.
"""
import logging
import re
import select
import subprocess
import time
from contextlib import contextmanager
from typing import Iterator, ContextManager, Sequence, IO, Text
from uuid import uuid4
from trains_agent.backend_api.services.tasks import Script
from trains_agent.backend_api.session.client import APIClient
from pathlib2 import Path
from pytest import fixture
from trains_agent.helper.process import Argv
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
log = logging.getLogger(__name__)
DEFAULT_TASK_ARGS = {"type": "testing", "name": "test", "input": {"view": {}}}
HERE = Path(__file__).resolve().parent
SHORT_TIMEOUT = 30
@fixture(scope="session")
def client():
return APIClient(api_version="2.2")
@contextmanager
def create_task(client, **params):
"""
Create task in backend
"""
log.info("creating new task")
task = client.tasks.create(**params)
try:
yield task
finally:
log.info("deleting task, id=%s", task.id)
task.delete(force=True)
def select_read(file_obj, timeout):
return select.select([file_obj], [], [], timeout)[0]
def run_task(task):
return Argv("trains_agent", "--debug", "worker", "execute", "--id", task.id)
@contextmanager
def iterate_output(timeout, command):
# type: (int, Argv) -> ContextManager[Iterator[Text]]
"""
Run `command` in a subprocess and return a contextmanager of an iterator
over its output's lines. If `timeout` seconds have passed, iterator ends and
the process is killed.
:param timeout: maximum amount of time to wait for command to end
:param command: command to run
"""
log.info("running: %s", command)
process = command.call_subprocess(
subprocess.Popen, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
try:
yield _iterate_output(timeout, process)
finally:
status = process.poll()
if status is not None:
log.info("command %s terminated, status: %s", command, status)
else:
log.info("killing command %s", command)
process.kill()
def _iterate_output(timeout, process):
# type: (int, subprocess.Popen) -> Iterator[Text]
"""
Return an iterator over process's output lines.
over its output's lines.
If `timeout` seconds have passed, iterator ends and the process is killed.
:param timeout: maximum amount of time to wait for command to end
:param process: process to iterate over its output lines
"""
start = time.time()
exit_loop = [] # type: Sequence[IO]
def loop_helper(file_obj):
# type: (IO) -> Sequence[IO]
diff = timeout - (time.time() - start)
if diff <= 0:
return exit_loop
return select_read(file_obj, timeout=diff)
buffer = ""
for output, in iter(lambda: loop_helper(process.stdout), exit_loop):
try:
added = output.read(1024).decode("utf8")
except EOFError:
if buffer:
yield buffer
return
buffer += added
lines = buffer.split("\n", 1)
while len(lines) > 1:
line, buffer = lines
log.debug("--- %s", line)
yield line
lines = buffer.split("\n", 1)
def search_lines(lines, search_for, error):
# type: (Iterator[Text], Text, Text) -> None
"""
Fail test if `search_for` string appears nowhere in `lines`.
Consumes `lines` up to point where `search_for` is found.
:param lines: lines to search in
:param search_for: string to search lines for
:param error: error to show if not found
"""
for line in lines:
if search_for in line:
break
else:
assert False, error
def search_lines_pattern(lines, pattern, error):
# type: (Iterator[Text], Text, Text) -> None
"""
Like `search_lines` but searches for a pattern.
:param lines: lines to search in
:param pattern: pattern to search lines for
:param error: error to show if not found
"""
for line in lines:
if re.search(pattern, line):
break
else:
assert False, error
def test_entry_point_warning(client):
"""
non-empty script.entry_point should output a warning
"""
with create_task(
client,
script=Script(diff="print('hello')", entry_point="foo.py", repository=""),
**DEFAULT_TASK_ARGS
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
for line in output:
if "found non-empty script.entry_point" in line:
break
else:
assert False, "did not find warning in output"
def test_run_no_dirs(client):
"""
The arbitrary `code` directory should be selected when there is no `script.repository`
"""
uuid = uuid4().hex
script = "print('{}')".format(uuid)
with create_task(
client,
script=Script(diff=script, entry_point="", repository="", working_dir=""),
**DEFAULT_TASK_ARGS
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
search_lines(
output,
"found literal script",
"task was not recognized as a literal script",
)
search_lines_pattern(
output,
r"selected execution directory:.*code",
r"did not selected empty `code` dir as execution dir",
)
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
def test_run_working_dir(client):
"""
Literal script tasks should respect `working_dir`
"""
uuid = uuid4().hex
script = "print('{}')".format(uuid)
with create_task(
client,
script=Script(
diff=script,
entry_point="",
repository="git@bitbucket.org:seematics/roee_test_git.git",
working_dir="space dir",
),
**DEFAULT_TASK_ARGS
) as task, iterate_output(120, run_task(task)) as output:
search_lines(
output,
"found literal script",
"task was not recognized as a literal script",
)
search_lines_pattern(
output,
r"selected execution directory:.*space dir",
r"did not selected working_dir as set in execution_info",
)
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
def test_regular_task(client):
"""
Test a plain old task
"""
with create_task(
client,
script=Script(
entry_point="noop.py",
repository="git@bitbucket.org:seematics/roee_test_git.git",
),
**DEFAULT_TASK_ARGS
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
message = "Done"
search_lines(
output, message, "did not reach dummy output message {}".format(message)
)
def test_regular_task_nested(client):
"""
`entry_point` should be relative to `working_dir` if present
"""
with create_task(
client,
script=Script(
entry_point="noop_nested.py",
working_dir="no_reqs",
repository="git@bitbucket.org:seematics/roee_test_git.git",
),
**DEFAULT_TASK_ARGS
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
message = "Done"
search_lines(
output, message, "did not reach dummy output message {}".format(message)
)

11
tests/pytest.ini Normal file
View File

@@ -0,0 +1,11 @@
[pytest]
python_files=*.py
script_launch_mode = subprocess
env =
PYTHONIOENCODING=utf-8
log_cli = 1
log_cli_level = DEBUG
log_cli_format = [%(name)s] [%(asctime)s] [%(levelname)8s] %(message)s
log_print = False
log_cli_date_format=%Y-%m-%d %H:%M:%S
addopts = -p no:warnings

2
tests/requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
pytest
-r ../requirements.txt

View File

View File

@@ -0,0 +1,13 @@
import os
def pytest_addoption(parser):
parser.addoption('--cpu', action='store_true')
def pytest_configure(config):
if not config.option.cpu:
return
os.environ['PATH'] = ':'.join(p for p in os.environ['PATH'].split(':') if 'cuda' not in p)
os.environ['CUDA_VERSION'] = ''
os.environ['CUDNN_VERSION'] = ''

View File

@@ -0,0 +1,350 @@
from __future__ import unicode_literals
import re
import subprocess
from itertools import chain
from os import path
from os.path import sep
from sys import platform as sys_platform
import pytest
import requirements
from trains_agent.commands.worker import Worker
from trains_agent.helper.package.pytorch import PytorchRequirement
from trains_agent.helper.package.requirements import RequirementsManager, \
RequirementSubstitution, MarkerRequirement
from trains_agent.helper.process import get_bash_output
from trains_agent.session import Session
_cuda_based_packages_hack = ('seematics.caffe', 'lightnet')
def old_get_suffix(session):
cuda_version = session.config['agent.cuda_version']
cudnn_version = session.config['agent.cudnn_version']
if cuda_version and cudnn_version:
nvcc_ver = cuda_version.strip()
cudnn_ver = cudnn_version.strip()
else:
if sys_platform == 'win32':
nvcc_ver = subprocess.check_output('nvcc --version'.split()).decode('utf-8'). \
replace('\r', '').split('\n')
else:
nvcc_ver = subprocess.check_output(
'nvcc --version'.split()).decode('utf-8').split('\n')
nvcc_ver = [l for l in nvcc_ver if 'release ' in l]
nvcc_ver = nvcc_ver[0][nvcc_ver[0].find(
'release ') + len('release '):][:3]
nvcc_ver = str(int(float(nvcc_ver) * 10))
if sys_platform == 'win32':
cuda_lib = subprocess.check_output('where nvcc'.split()).decode('utf-8'). \
replace('\r', '').split('\n')[0]
cudnn_h = sep.join(cuda_lib.split(
sep)[:-2] + ['include', 'cudnn.h'])
else:
cuda_lib = subprocess.check_output(
'which nvcc'.split()).decode('utf-8').split('\n')[0]
cudnn_h = path.join(
sep, *(cuda_lib.split(sep)[:-2] + ['include', 'cudnn.h']))
cudnn_mj, cudnn_mi = None, None
for l in open(cudnn_h, 'r'):
if 'CUDNN_MAJOR' in l:
cudnn_mj = l.split()[-1]
if 'CUDNN_MINOR' in l:
cudnn_mi = l.split()[-1]
if cudnn_mj and cudnn_mi:
break
cudnn_ver = cudnn_mj + ('0' if not cudnn_mi else cudnn_mi)
# build cuda + cudnn version suffix
# make sure these are integers, someone else will catch the exception
pkg_suffix_ver = '.post' + \
str(int(nvcc_ver)) + '.dev' + str(int(cudnn_ver))
return pkg_suffix_ver, nvcc_ver, cudnn_ver
def old_replace(session, line):
try:
cuda_ver_suffix, cuda_ver, cuda_cudnn_ver = old_get_suffix(session)
except Exception:
return line
if line.lstrip().startswith('#'):
return line
for package_name in _cuda_based_packages_hack:
if package_name not in line:
continue
try:
line_lstrip = line.lstrip()
if line_lstrip.startswith('http://') or line_lstrip.startswith('https://'):
pos = line.find(package_name) + len(package_name)
# patch line with specific version
line = line[:pos] + \
line[pos:].replace('-cp', cuda_ver_suffix + '-cp', 1)
else:
# this is a pypi package
tokens = line.replace('=', ' ').replace('<', ' ').replace('>', ' ').replace(';', ' '). \
replace('!', ' ').split()
if package_name != tokens[0]:
# how did we get here, probably a mistake
found_cuda_based_package = False
continue
version_number = None
if len(tokens) > 1:
# get the package version info
test_version_number = tokens[1]
# check if we have a valid version, i.e. does not contain post/dev
version_number = '.'.join([v for v in test_version_number.split('.')
if v and '0' <= v[0] <= '9'])
if version_number != test_version_number:
raise ValueError()
# we have no version, but we have to have one
if not version_number:
# get the latest version from the extra index list
pip_search_cmd = ['pip', 'search']
if Worker._pip_extra_index_url:
pip_search_cmd.extend(
chain.from_iterable(('-i', x) for x in Worker._pip_extra_index_url))
pip_search_cmd += [package_name]
pip_search_output = get_bash_output(
' '.join(pip_search_cmd), strip=True)
version_number = pip_search_output.split(package_name)[1]
version_number = version_number.replace(
'(', ' ').replace(')', ' ').split()[0]
version_number = '.'.join([v for v in version_number.split('.')
if v and '0' <= v[0] <= '9'])
if not version_number:
# somewhere along the way we failed
raise ValueError()
package_name_version = package_name + '==' + version_number + cuda_ver_suffix
if version_number in line:
# make sure we have the specific version not >=
tokens = line.split(';')
line = ';'.join([package_name_version] + tokens[1:])
else:
# add version to the package_name
line = line.replace(package_name, package_name_version, 1)
# print('pip install %s using CUDA v%s CuDNN v%s' % (package_name, cuda_ver, cuda_cudnn_ver))
except ValueError:
pass
# print('Warning! could not find installed CUDA/CuDNN version for %s, '
# 'using original requirements line: %s' % (package_name, line))
# add the current line into the cuda requirements list
return line
win_condition = 'sys_platform != "win_32"'
versions = ('1', '1.4', '1.4.9', '1.5.3.dev0',
'1.5.3.dev3', '43.1.2.dev0.post1')
def normalize(result):
return result and re.sub(' ?; ?', ';', result)
def parse_one(requirement):
return MarkerRequirement(next(requirements.parse(requirement)))
def compare(manager, current_session, pytest_config, requirement, expected):
try:
res1 = normalize(manager._replace_one(parse_one(requirement)))
except ValueError:
res1 = None
res2 = old_replace(current_session, requirement)
if res2 == requirement:
res2 = None
res2 = normalize(res2)
expected = normalize(expected)
if pytest_config.option.cpu:
assert res1 is None
else:
assert res1 == expected
if requirement not in FAILURES:
assert res1 == res2
return res1
ARG_VERSIONED = pytest.mark.parametrize('arg', (
'nothing{op}{version}{extra}',
'something_else{op}{version}{extra}',
'something-else{op}{version}{extra}',
'something.else{op}{version}{extra}',
'seematics.caffe{op}{version}{extra}',
'lightnet{op}{version}{extra}',
))
OP = pytest.mark.parametrize('op', ('==', '<=', '>='))
VERSION = pytest.mark.parametrize('version', versions)
EXTRA = pytest.mark.parametrize('extra', ('', ' ; ' + win_condition))
ARG_PLAIN = pytest.mark.parametrize('arg', (
'nothing',
'something_else',
'something-else',
'something.else',
'seematics.caffe',
'lightnet',
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.config-1.0.2-py2.py3-none-any.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.api-1.0.2-py2.py3-none-any.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.sdk-1.1.0-py2.py3-none-any.whl',
))
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_with_version(manager, current_session, pytestconfig, arg, op, version, extra):
requirement = arg.format(**locals())
# expected = EXPECTED.get((arg, op, version))
expected = get_expected(current_session, (arg, op, version))
expected = expected and expected + extra
compare(manager, current_session, pytestconfig, requirement, expected)
@ARG_PLAIN
@EXTRA
def test_plain(arg, current_session, pytestconfig, extra, manager):
# expected = EXPECTED.get(arg)
expected = get_expected(current_session, arg)
expected = expected and expected + extra
compare(manager, current_session, pytestconfig, arg + extra, expected)
def get_expected(session, key):
result = EXPECTED.get(key)
cuda_version, cudnn_version = session.config['agent.cuda_version'], session.config['agent.cudnn_version']
suffix = '.post{cuda_version}.dev{cudnn_version}'.format(**locals()) \
if (cuda_version and cudnn_version) \
else ''
return result and result.format(suffix=suffix)
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_str_versioned(arg, op, version, extra):
requirement = arg.format(**locals())
assert normalize(str(parse_one(requirement))) == normalize(requirement)
@ARG_PLAIN
@EXTRA
def test_str_plain(arg, extra, manager):
requirement = arg.format(**locals())
assert normalize(str(parse_one(requirement))) == normalize(requirement)
@pytest.fixture(scope='session')
def current_session(pytestconfig):
session = Session()
if not pytestconfig.option.cpu:
return session
session.config['agent.cuda_version'] = None
session.config['agent.cudnn_version'] = None
return session
@pytest.fixture(scope='session')
def manager(current_session):
manager = RequirementsManager(current_session.config)
for requirement in (PytorchRequirement, ):
manager.register(requirement)
return manager
SPECS = (
# plain
'',
# greater than
'>=0', '>0.1', '>=0.1', '>0.1.2', '>=0.1.2', '>0.1.2.post30', '>=0.1.3.post30', '>0.post30.dev40',
'>=0.post30.dev40', '>0.post30-dev40', '>=0.post30-dev40', '>0.0', '>=0.0', '>0.3', '>=0.3', '>=0.4',
# smaller than
'<4', '<4.0', '<4.1.0', '<4.1.0.post80.dev60', '<4.1.0-post80.dev60', '<3', '<2.0', '<1.0.3', '<=4', '<=4.0',
'<=4.1.0', '<=4.1.0.post80.dev60', '<=4.1.0-post80.dev60', '<=3', '<=2.0', '<=1.0.3',
# equals
'==0.4.0',
# equals and
'==0.4.0,>=0', '==0.4.0,>=0.1', '==0.4.0,>0.1', '==0.4.0,>=0.1.2', '==0.4.0,>0.1.2', '==0.4.0,<4', '==0.4.0,<=4',
'==0.4.0,<4.0', '==0.4.0,<=4.0', '==0.4.0,<4.0.2',
# smaller and greater
'>=0,<4', '>=0,<4.0', '>0.1,<1', '>0.1,<1.1.2', '>0.1,<1.1.2', '>=0,<=4', '>=0,<4.0',
'>=0.1,<1', '>0.1,<=1.1.2', '>=0.1,<1.1.2',
)
@pytest.mark.parametrize('package_manager', ('pip', 'conda'))
@pytest.mark.parametrize('os', ('linux', 'windows', 'macos'))
@pytest.mark.parametrize('cuda', ('80', '90', '91'))
@pytest.mark.parametrize('python', ('2.7', '3.5', '3.6'))
@pytest.mark.parametrize('spec', SPECS)
@pytest.mark.parametrize('condition', ('', ' ; ' + win_condition))
def test_pytorch_success(manager, package_manager, os, cuda, python, spec, condition):
pytorch_handler = manager.handlers[-1]
pytorch_handler.package_manager = package_manager
pytorch_handler.os = os
pytorch_handler.cuda = cuda_ver = 'cuda{}'.format(cuda)
pytorch_handler.python = python_ver = 'python{}'.format(python)
req = 'torch{}{}'.format(spec, condition)
expected = pytorch_handler.MAP[package_manager][os][cuda_ver][python_ver]
if isinstance(expected, Exception):
with pytest.raises(type(expected)):
manager._replace_one(parse_one(req))
else:
expected = expected['0.4.0']
result = manager._replace_one(parse_one(req))
assert result == expected
get_pip_version = RequirementSubstitution.get_pip_version
EXPECTED = {
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl':
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp35'
'-cp35m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl':
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp27'
'-cp27m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl':
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp35-cp35m'
'-linux_x86_64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl':
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp27-cp27mu'
'-linux_x86_64.whl',
'seematics.caffe': 'seematics.caffe=={}{{suffix}}'.format(get_pip_version('seematics.caffe')),
'lightnet': 'lightnet=={}{{suffix}}'.format(get_pip_version('lightnet')),
('seematics.caffe{op}{version}{extra}', '<=', '1'): 'seematics.caffe==1{suffix}',
('seematics.caffe{op}{version}{extra}', '<=', '1.4'): 'seematics.caffe==1.4{suffix}',
('seematics.caffe{op}{version}{extra}', '<=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
('seematics.caffe{op}{version}{extra}', '==', '1'): 'seematics.caffe==1{suffix}',
('seematics.caffe{op}{version}{extra}', '==', '1.4'): 'seematics.caffe==1.4{suffix}',
('seematics.caffe{op}{version}{extra}', '==', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
('seematics.caffe{op}{version}{extra}', '>=', '1'): 'seematics.caffe==1{suffix}',
('seematics.caffe{op}{version}{extra}', '>=', '1.4'): 'seematics.caffe==1.4{suffix}',
('seematics.caffe{op}{version}{extra}', '>=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
('lightnet{op}{version}{extra}', '<=', '1'): 'lightnet==1{suffix}',
('lightnet{op}{version}{extra}', '<=', '1.4'): 'lightnet==1.4{suffix}',
('lightnet{op}{version}{extra}', '<=', '1.4.9'): 'lightnet==1.4.9{suffix}',
('lightnet{op}{version}{extra}', '==', '1'): 'lightnet==1{suffix}',
('lightnet{op}{version}{extra}', '==', '1.4'): 'lightnet==1.4{suffix}',
('lightnet{op}{version}{extra}', '==', '1.4.9'): 'lightnet==1.4.9{suffix}',
('lightnet{op}{version}{extra}', '>=', '1'): 'lightnet==1{suffix}',
('lightnet{op}{version}{extra}', '>=', '1.4'): 'lightnet==1.4{suffix}',
('lightnet{op}{version}{extra}', '>=', '1.4.9'): 'lightnet==1.4.9{suffix}',
}
FAILURES = {
'seematics.caffe ; {}'.format(win_condition),
'lightnet ; {}'.format(win_condition)
}

View File

View File

@@ -0,0 +1,7 @@
def main():
assert 1 / 2 == 0
print('success')
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,8 @@
def main():
if not (1 / 2 == 0.5):
raise ValueError('failure')
print('success')
if __name__ == "__main__":
main()