Python 2.7 support

This commit is contained in:
allegroai 2019-06-13 01:55:36 +03:00
parent dfcd855975
commit ea3b5856fd
8 changed files with 34 additions and 21 deletions

View File

@ -39,7 +39,7 @@ class TensorBoardImage(TensorBoard):
encoded_image_string=image_string)
def on_epoch_end(self, epoch, logs={}):
super().on_epoch_end(epoch, logs)
super(TensorBoardImage, self).on_epoch_end(epoch, logs)
import tensorflow as tf
images = self.validation_data[0] # 0 - data; 1 - labels
img = (255 * images[0].reshape(28, 28)).astype('uint8')

View File

@ -59,6 +59,7 @@ import torchvision.models as models
import copy
from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with matplotlib example', task_type=Task.TaskTypes.testing)
@ -95,10 +96,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# with name ``images`` in your current working directory.
# desired size of the output image
STYLE_IMAGE_PATH = "./samples/picasso.jpg"
CONTENT_IMAGE_PATH = "./samples/dancing.jpg"
imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
loader = transforms.Compose([
@ -113,8 +110,8 @@ def image_loader(image_name):
return image.to(device, torch.float)
style_img = image_loader(STYLE_IMAGE_PATH)
content_img = image_loader(CONTENT_IMAGE_PATH)
style_img = image_loader("./samples/picasso.jpg")
content_img = image_loader("./samples/dancing.jpg")
assert style_img.size() == content_img.size(), \
"we need to import style and content images of the same size"
@ -169,7 +166,7 @@ imshow(content_img, title='Content Image')
# computed at the desired layers and because of auto grad, all the
# gradients will be computed. Now, in order to make the content loss layer
# transparent we must define a ``forward`` method that computes the content
# loss and then returns the layers input. The computed loss is saved as a
# loss and then returns the layer's input. The computed loss is saved as a
# parameter of the module.
#
@ -259,7 +256,7 @@ class StyleLoss(nn.Module):
# Now we need to import a pre-trained neural network. We will use a 19
# layer VGG network like the one used in the paper.
#
# PyTorchs implementation of VGG is a module divided into two child
# PyTorch's implementation of VGG is a module divided into two child
# ``Sequential`` modules: ``features`` (containing convolution and pooling layers),
# and ``classifier`` (containing fully connected layers). We will use the
# ``features`` module because we need the output of the individual
@ -299,7 +296,7 @@ class Normalization(nn.Module):
######################################################################
# A ``Sequential`` module contains an ordered list of child modules. For
# instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU, MaxPool2d,
# Conv2d, ReLU) aligned in the right order of depth. We need to add our
# Conv2d, ReLU...) aligned in the right order of depth. We need to add our
# content loss and style loss layers immediately after the convolution
# layer they are detecting. To do this we must create a new ``Sequential``
# module that has content loss and style loss modules correctly inserted.
@ -407,7 +404,7 @@ def get_input_optimizer(input_img):
# Finally, we must define a function that performs the neural transfer. For
# each iteration of the networks, it is fed an updated input and computes
# new losses. We will run the ``backward`` methods of each loss module to
# dynamicaly compute their gradients. The optimizer requires a “closure”
# dynamicaly compute their gradients. The optimizer requires a "closure"
# function, which reevaluates the modul and returns the loss.
#
# We still have one final constraint to address. The network may try to

View File

@ -7,7 +7,7 @@ import six
from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload

View File

@ -1,4 +1,7 @@
import json
import six
from . import get_cache_dir
from .defs import SESSION_CACHE_FILE
@ -12,7 +15,8 @@ class SessionCache(object):
@classmethod
def _load_cache(cls):
try:
with (get_cache_dir() / SESSION_CACHE_FILE).open("rt") as fp:
flag = 'rb' if six.PY2 else 'rt'
with (get_cache_dir() / SESSION_CACHE_FILE).open(flag) as fp:
return json.load(fp)
except Exception:
return {}
@ -21,7 +25,8 @@ class SessionCache(object):
def _store_cache(cls, cache):
try:
get_cache_dir().mkdir(parents=True, exist_ok=True)
with (get_cache_dir() / SESSION_CACHE_FILE).open("wt") as fp:
flag = 'wb' if six.PY2 else 'wt'
with (get_cache_dir() / SESSION_CACHE_FILE).open(flag) as fp:
json.dump(cache, fp)
except Exception:
pass

View File

@ -1,3 +1,5 @@
from __future__ import print_function
from pyhocon import ConfigFactory
from pathlib2 import Path
from six.moves.urllib.parse import urlparse, urlunparse

View File

@ -4,7 +4,7 @@ import threading
import weakref
from collections import defaultdict
from logging import ERROR, WARNING, getLogger
from pathlib import Path
from pathlib2 import Path
import cv2
import numpy as np
@ -99,10 +99,11 @@ class PostImportHookPatching(object):
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
ident = threading.get_ident()
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1
ret = None
try:
ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex:

View File

@ -2,8 +2,10 @@ import sys
import cv2
import numpy as np
import six
from six import BytesIO
from ..debugging.log import LoggerRoot
from ..config import running_remotely
@ -36,6 +38,9 @@ class PatchedMatplotlib:
# we support matplotlib version 2.0.0 and above
import matplotlib
if int(matplotlib.__version__.split('.')[0]) < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__))
return False
if running_remotely():
@ -44,10 +49,13 @@ class PatchedMatplotlib:
import matplotlib.pyplot
sys.modules['matplotlib'].pyplot.switch_backend('agg')
import matplotlib.pyplot as plt
import plotly.tools as tls
from matplotlib import _pylab_helpers
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
else:
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
# patch plotly so we know it failed us.
@ -103,7 +111,6 @@ class PatchedMatplotlib:
# noinspection PyBroadException
try:
import matplotlib.pyplot as plt
import plotly.tools as tls
from plotly import optional_imports
from matplotlib import _pylab_helpers
# store the figure object we just created (if it is not already there)
@ -136,7 +143,8 @@ class PatchedMatplotlib:
buffer_ = BytesIO()
plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0)
buffer_.seek(0)
image = cv2.imdecode(np.frombuffer(buffer_.getbuffer(), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue()
image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
# check if we need to restore the active object
if set_active and not _pylab_helpers.Gcf.get_active():