Python 2.7 support

This commit is contained in:
allegroai 2019-06-13 01:55:36 +03:00
parent dfcd855975
commit ea3b5856fd
8 changed files with 34 additions and 21 deletions

View File

@ -39,7 +39,7 @@ class TensorBoardImage(TensorBoard):
encoded_image_string=image_string) encoded_image_string=image_string)
def on_epoch_end(self, epoch, logs={}): def on_epoch_end(self, epoch, logs={}):
super().on_epoch_end(epoch, logs) super(TensorBoardImage, self).on_epoch_end(epoch, logs)
import tensorflow as tf import tensorflow as tf
images = self.validation_data[0] # 0 - data; 1 - labels images = self.validation_data[0] # 0 - data; 1 - labels
img = (255 * images[0].reshape(28, 28)).astype('uint8') img = (255 * images[0].reshape(28, 28)).astype('uint8')

View File

@ -59,6 +59,7 @@ import torchvision.models as models
import copy import copy
from trains import Task from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with matplotlib example', task_type=Task.TaskTypes.testing) task = Task.init(project_name='examples', task_name='pytorch with matplotlib example', task_type=Task.TaskTypes.testing)
@ -95,10 +96,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# with name ``images`` in your current working directory. # with name ``images`` in your current working directory.
# desired size of the output image # desired size of the output image
STYLE_IMAGE_PATH = "./samples/picasso.jpg"
CONTENT_IMAGE_PATH = "./samples/dancing.jpg"
imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
loader = transforms.Compose([ loader = transforms.Compose([
@ -113,8 +110,8 @@ def image_loader(image_name):
return image.to(device, torch.float) return image.to(device, torch.float)
style_img = image_loader(STYLE_IMAGE_PATH) style_img = image_loader("./samples/picasso.jpg")
content_img = image_loader(CONTENT_IMAGE_PATH) content_img = image_loader("./samples/dancing.jpg")
assert style_img.size() == content_img.size(), \ assert style_img.size() == content_img.size(), \
"we need to import style and content images of the same size" "we need to import style and content images of the same size"
@ -169,7 +166,7 @@ imshow(content_img, title='Content Image')
# computed at the desired layers and because of auto grad, all the # computed at the desired layers and because of auto grad, all the
# gradients will be computed. Now, in order to make the content loss layer # gradients will be computed. Now, in order to make the content loss layer
# transparent we must define a ``forward`` method that computes the content # transparent we must define a ``forward`` method that computes the content
# loss and then returns the layers input. The computed loss is saved as a # loss and then returns the layer's input. The computed loss is saved as a
# parameter of the module. # parameter of the module.
# #
@ -259,7 +256,7 @@ class StyleLoss(nn.Module):
# Now we need to import a pre-trained neural network. We will use a 19 # Now we need to import a pre-trained neural network. We will use a 19
# layer VGG network like the one used in the paper. # layer VGG network like the one used in the paper.
# #
# PyTorchs implementation of VGG is a module divided into two child # PyTorch's implementation of VGG is a module divided into two child
# ``Sequential`` modules: ``features`` (containing convolution and pooling layers), # ``Sequential`` modules: ``features`` (containing convolution and pooling layers),
# and ``classifier`` (containing fully connected layers). We will use the # and ``classifier`` (containing fully connected layers). We will use the
# ``features`` module because we need the output of the individual # ``features`` module because we need the output of the individual
@ -299,7 +296,7 @@ class Normalization(nn.Module):
###################################################################### ######################################################################
# A ``Sequential`` module contains an ordered list of child modules. For # A ``Sequential`` module contains an ordered list of child modules. For
# instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU, MaxPool2d, # instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU, MaxPool2d,
# Conv2d, ReLU) aligned in the right order of depth. We need to add our # Conv2d, ReLU...) aligned in the right order of depth. We need to add our
# content loss and style loss layers immediately after the convolution # content loss and style loss layers immediately after the convolution
# layer they are detecting. To do this we must create a new ``Sequential`` # layer they are detecting. To do this we must create a new ``Sequential``
# module that has content loss and style loss modules correctly inserted. # module that has content loss and style loss modules correctly inserted.
@ -407,7 +404,7 @@ def get_input_optimizer(input_img):
# Finally, we must define a function that performs the neural transfer. For # Finally, we must define a function that performs the neural transfer. For
# each iteration of the networks, it is fed an updated input and computes # each iteration of the networks, it is fed an updated input and computes
# new losses. We will run the ``backward`` methods of each loss module to # new losses. We will run the ``backward`` methods of each loss module to
# dynamicaly compute their gradients. The optimizer requires a “closure” # dynamicaly compute their gradients. The optimizer requires a "closure"
# function, which reevaluates the modul and returns the loss. # function, which reevaluates the modul and returns the loss.
# #
# We still have one final constraint to address. The network may try to # We still have one final constraint to address. The network may try to

View File

@ -7,7 +7,7 @@ import six
from ..base import InterfaceBase from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
from ...utilities.py3_interop import AbstractContextManager from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload

View File

@ -1,4 +1,7 @@
import json import json
import six
from . import get_cache_dir from . import get_cache_dir
from .defs import SESSION_CACHE_FILE from .defs import SESSION_CACHE_FILE
@ -12,7 +15,8 @@ class SessionCache(object):
@classmethod @classmethod
def _load_cache(cls): def _load_cache(cls):
try: try:
with (get_cache_dir() / SESSION_CACHE_FILE).open("rt") as fp: flag = 'rb' if six.PY2 else 'rt'
with (get_cache_dir() / SESSION_CACHE_FILE).open(flag) as fp:
return json.load(fp) return json.load(fp)
except Exception: except Exception:
return {} return {}
@ -21,7 +25,8 @@ class SessionCache(object):
def _store_cache(cls, cache): def _store_cache(cls, cache):
try: try:
get_cache_dir().mkdir(parents=True, exist_ok=True) get_cache_dir().mkdir(parents=True, exist_ok=True)
with (get_cache_dir() / SESSION_CACHE_FILE).open("wt") as fp: flag = 'wb' if six.PY2 else 'wt'
with (get_cache_dir() / SESSION_CACHE_FILE).open(flag) as fp:
json.dump(cache, fp) json.dump(cache, fp)
except Exception: except Exception:
pass pass

View File

@ -1,3 +1,5 @@
from __future__ import print_function
from pyhocon import ConfigFactory from pyhocon import ConfigFactory
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse, urlunparse from six.moves.urllib.parse import urlparse, urlunparse

View File

@ -4,7 +4,7 @@ import threading
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from logging import ERROR, WARNING, getLogger from logging import ERROR, WARNING, getLogger
from pathlib import Path from pathlib2 import Path
import cv2 import cv2
import numpy as np import numpy as np
@ -99,10 +99,11 @@ class PostImportHookPatching(object):
def _patched_call(original_fn, patched_fn): def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs): def _inner_patch(*args, **kwargs):
ident = threading.get_ident() ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard: if ident in _recursion_guard:
return original_fn(*args, **kwargs) return original_fn(*args, **kwargs)
_recursion_guard[ident] = 1 _recursion_guard[ident] = 1
ret = None
try: try:
ret = patched_fn(original_fn, *args, **kwargs) ret = patched_fn(original_fn, *args, **kwargs)
except Exception as ex: except Exception as ex:

View File

@ -2,8 +2,10 @@ import sys
import cv2 import cv2
import numpy as np import numpy as np
import six
from six import BytesIO from six import BytesIO
from ..debugging.log import LoggerRoot
from ..config import running_remotely from ..config import running_remotely
@ -36,6 +38,9 @@ class PatchedMatplotlib:
# we support matplotlib version 2.0.0 and above # we support matplotlib version 2.0.0 and above
import matplotlib import matplotlib
if int(matplotlib.__version__.split('.')[0]) < 2: if int(matplotlib.__version__.split('.')[0]) < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__))
return False return False
if running_remotely(): if running_remotely():
@ -44,10 +49,13 @@ class PatchedMatplotlib:
import matplotlib.pyplot import matplotlib.pyplot
sys.modules['matplotlib'].pyplot.switch_backend('agg') sys.modules['matplotlib'].pyplot.switch_backend('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import plotly.tools as tls
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show if six.PY2:
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
else:
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow # sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
# patch plotly so we know it failed us. # patch plotly so we know it failed us.
@ -103,7 +111,6 @@ class PatchedMatplotlib:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import plotly.tools as tls
from plotly import optional_imports from plotly import optional_imports
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
# store the figure object we just created (if it is not already there) # store the figure object we just created (if it is not already there)
@ -136,7 +143,8 @@ class PatchedMatplotlib:
buffer_ = BytesIO() buffer_ = BytesIO()
plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0) plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0)
buffer_.seek(0) buffer_.seek(0)
image = cv2.imdecode(np.frombuffer(buffer_.getbuffer(), dtype=np.uint8), cv2.IMREAD_UNCHANGED) buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue()
image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
# check if we need to restore the active object # check if we need to restore the active object
if set_active and not _pylab_helpers.Gcf.get_active(): if set_active and not _pylab_helpers.Gcf.get_active():