Updated to tensorflow 2 and added installation commands

This commit is contained in:
danmalowany-allegro 2020-01-05 19:57:50 +02:00 committed by GitHub
parent 54ae340ccb
commit 5abd01fa75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,76 +3,152 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
}
],
"source": [
"# Trains - Example of integrating plots and training on jupyter notebook. \n",
"# In this example, simple graphs are shown, then an MNIST classifier is trained using Keras.\n",
"\n",
"from keras.callbacks import TensorBoard, ModelCheckpoint\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense, Dropout, Activation\n",
"from keras.optimizers import SGD, Adam, RMSprop\n",
"from keras.utils import np_utils\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAINS Task: overwriting (reusing) task id=8d23de406d0a4159a496b64c7eba0e32\n",
"======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY https://github.com/allegroai/trains.git <======\n",
"TRAINS results page: https://demoapp.trainsai.io/projects/087f765c846c4c76a7e9f3d035667d82/experiments/8d23de406d0a4159a496b64c7eba0e32/output/log\n"
"Requirement already up-to-date: trains in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (0.12.2)\n",
"Requirement already satisfied, skipping upgrade: attrs>=18.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (19.3.0)\n",
"Requirement already satisfied, skipping upgrade: urllib3>=1.21.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.25.7)\n",
"Requirement already satisfied, skipping upgrade: typing>=3.6.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (3.6.4)\n",
"Requirement already satisfied, skipping upgrade: funcsigs>=1.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.0.2)\n",
"Requirement already satisfied, skipping upgrade: PyYAML>=3.12 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (5.2)\n",
"Requirement already satisfied, skipping upgrade: pyjwt>=1.6.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.7.1)\n",
"Requirement already satisfied, skipping upgrade: plotly>=3.9.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (4.2.1)\n",
"Requirement already satisfied, skipping upgrade: requests>=2.20.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.22.0)\n",
"Requirement already satisfied, skipping upgrade: jsonschema>=2.6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (3.2.0)\n",
"Requirement already satisfied, skipping upgrade: requests-file>=1.4.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.4.3)\n",
"Requirement already satisfied, skipping upgrade: pigar==0.9.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (0.9.2)\n",
"Requirement already satisfied, skipping upgrade: psutil>=3.4.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (5.6.7)\n",
"Requirement already satisfied, skipping upgrade: Pillow>=4.1.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (6.2.1)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.10 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.18.0)\n",
"Requirement already satisfied, skipping upgrade: pathlib2>=2.3.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.3.5)\n",
"Requirement already satisfied, skipping upgrade: jsonmodels>=2.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.4)\n",
"Requirement already satisfied, skipping upgrade: future>=0.16.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (0.18.2)\n",
"Requirement already satisfied, skipping upgrade: pyparsing>=2.0.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.4.5)\n",
"Requirement already satisfied, skipping upgrade: tqdm>=4.19.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (4.41.1)\n",
"Requirement already satisfied, skipping upgrade: furl>=2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.1.0)\n",
"Requirement already satisfied, skipping upgrade: humanfriendly>=2.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (4.18)\n",
"Requirement already satisfied, skipping upgrade: six>=1.11.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (1.13.0)\n",
"Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from trains) (2.8.0)\n",
"Requirement already satisfied, skipping upgrade: retrying>=1.3.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from plotly>=3.9.0->trains) (1.3.3)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from requests>=2.20.0->trains) (2019.11.28)\n",
"Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from requests>=2.20.0->trains) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from requests>=2.20.0->trains) (2.8)\n",
"Requirement already satisfied, skipping upgrade: pyrsistent>=0.14.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from jsonschema>=2.6.0->trains) (0.15.6)\n",
"Requirement already satisfied, skipping upgrade: setuptools in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from jsonschema>=2.6.0->trains) (44.0.0)\n",
"Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from jsonschema>=2.6.0->trains) (1.3.0)\n",
"Requirement already satisfied, skipping upgrade: colorama>=0.3.9 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from pigar==0.9.2->trains) (0.4.3)\n",
"Requirement already satisfied, skipping upgrade: orderedmultidict>=1.0.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from furl>=2.0.0->trains) (1.0.1)\n",
"Requirement already satisfied, skipping upgrade: zipp>=0.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema>=2.6.0->trains) (0.6.0)\n",
"Requirement already satisfied, skipping upgrade: more-itertools in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from zipp>=0.5->importlib-metadata; python_version < \"3.8\"->jsonschema>=2.6.0->trains) (8.0.2)\n",
"Requirement already up-to-date: numpy==1.18.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (1.18.0)\n",
"Requirement already up-to-date: tensorflow==2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (2.0.0)\n",
"Requirement already satisfied, skipping upgrade: keras-applications>=1.0.8 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.0.8)\n",
"Requirement already satisfied, skipping upgrade: gast==0.2.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (0.2.2)\n",
"Requirement already satisfied, skipping upgrade: opt-einsum>=2.3.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (3.1.0)\n",
"Requirement already satisfied, skipping upgrade: tensorflow-estimator<2.1.0,>=2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (2.0.1)\n",
"Requirement already satisfied, skipping upgrade: grpcio>=1.8.6 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.26.0)\n",
"Requirement already satisfied, skipping upgrade: six>=1.10.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.13.0)\n",
"Requirement already satisfied, skipping upgrade: termcolor>=1.1.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.1.0)\n",
"Requirement already satisfied, skipping upgrade: astor>=0.6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (0.8.1)\n",
"Requirement already satisfied, skipping upgrade: wheel>=0.26 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (0.33.6)\n",
"Requirement already satisfied, skipping upgrade: protobuf>=3.6.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (3.11.2)\n",
"Requirement already satisfied, skipping upgrade: google-pasta>=0.1.6 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (0.1.8)\n",
"Requirement already satisfied, skipping upgrade: wrapt>=1.11.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.11.2)\n",
"Requirement already satisfied, skipping upgrade: numpy<2.0,>=1.16.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.18.0)\n",
"Requirement already satisfied, skipping upgrade: absl-py>=0.7.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: tensorboard<2.1.0,>=2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (2.0.0)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied, skipping upgrade: keras-preprocessing>=1.0.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorflow==2.0.0) (1.1.0)\n",
"Requirement already satisfied, skipping upgrade: h5py in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from keras-applications>=1.0.8->tensorflow==2.0.0) (2.8.0)\n",
"Requirement already satisfied, skipping upgrade: setuptools in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from protobuf>=3.6.1->tensorflow==2.0.0) (44.0.0)\n",
"Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard<2.1.0,>=2.0.0->tensorflow==2.0.0) (0.16.0)\n",
"Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard<2.1.0,>=2.0.0->tensorflow==2.0.0) (3.1.1)\n",
"Requirement already up-to-date: tensorboard==2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (2.0.0)\n",
"Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (3.1.1)\n",
"Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (0.16.0)\n",
"Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (0.33.6)\n",
"Requirement already satisfied, skipping upgrade: grpcio>=1.6.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (1.26.0)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.12.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (1.18.0)\n",
"Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (3.11.2)\n",
"Requirement already satisfied, skipping upgrade: six>=1.10.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (1.13.0)\n",
"Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (from tensorboard==2.0.0) (44.0.0)\n"
]
}
],
"source": [
"# Trains - Example of integrating plots and training on jupyter notebook. \n",
"# In this example, simple graphs are shown, then an MNIST classifier is trained using Keras.!pip install -U pip\n",
"!pip install -U trains\n",
"!pip install -U numpy==1.18.0\n",
"!pip install -U tensorflow==2.0.0\n",
"!pip install -U tensorboard==2.0.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout, Activation\n",
"from tensorflow.keras.optimizers import SGD, Adam, RMSprop\n",
"from tensorflow.keras.utils import to_categorical\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAINS Task: created new task id=83ebee675e9f4b50af88da70be6a30d6\n",
"2020-01-05 17:51:21,909 - trains.Task - INFO - No repository found, storing script code instead\n",
"TRAINS results page: https://demoapp.trains.allegro.ai/projects/087f765c846c4c76a7e9f3d035667d82/experiments/83ebee675e9f4b50af88da70be6a30d6/output/log\n"
]
}
],
"source": [
"# Connecting TRAINS\n",
"from trains import Task\n",
"task = Task.init(project_name = 'examples', task_name = 'notebook example')\n"
"task = Task.init(project_name = 'examples', task_name = 'notebook example')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [],
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAINS Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring\n"
]
}
],
"source": [
"# Set script parameters\n",
"task_params = {'num_scatter_samples': 60, 'sin_max_value': 20, 'sin_steps': 30}\n",
"task_params = task.connect(task_params)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simple plots. You can view the plots in experiments results page "
"task_params = task.connect(task_params)"
]
},
{
@ -140,19 +216,10 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras training example\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# Notice, Updating task_params is traced and updated in TRAINS\n",
@ -162,17 +229,13 @@
"task_params['hidden_dim'] = 512\n",
"batch_size = task_params['batch_size']\n",
"nb_classes = task_params['nb_classes']\n",
"nb_epoch = task_params['nb_epoch']\n"
"nb_epoch = task_params['nb_epoch']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [
{
"name": "stdout",
@ -180,46 +243,59 @@
"text": [
"60000 train samples\n",
"10000 test samples\n",
"WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_1 (Dense) (None, 512) 401920 \n",
"dense (Dense) (None, 512) 401920 \n",
"_________________________________________________________________\n",
"activation (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 512) 262656 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 512) 262656 \n",
"dense_2 (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation_3 (Activation) (None, 10) 0 \n",
"activation_2 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 669,706\n",
"Trainable params: 669,706\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n",
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/6\n",
"60000/60000 [==============================] - 4s 64us/step - loss: 0.2129 - acc: 0.9350 - val_loss: 0.1012 - val_acc: 0.9682\n",
"59264/60000 [============================>.] - ETA: 0s - loss: 0.2217 - accuracy: 0.93122020-01-05 17:51:43,884 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:43,906 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:43,909 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 6s 99us/sample - loss: 0.2206 - accuracy: 0.9316 - val_loss: 0.1300 - val_accuracy: 0.9586\n",
"Epoch 2/6\n",
"60000/60000 [==============================] - 4s 68us/step - loss: 0.0813 - acc: 0.9752 - val_loss: 0.0684 - val_acc: 0.9779\n",
"59264/60000 [============================>.] - ETA: 0s - loss: 0.0808 - accuracy: 0.97512020-01-05 17:51:49,017 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:49,039 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:49,043 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 5s 84us/sample - loss: 0.0804 - accuracy: 0.9752 - val_loss: 0.0794 - val_accuracy: 0.9765\n",
"Epoch 3/6\n",
"60000/60000 [==============================] - 4s 62us/step - loss: 0.0540 - acc: 0.9830 - val_loss: 0.0736 - val_acc: 0.9793\n",
"59648/60000 [============================>.] - ETA: 0s - loss: 0.0542 - accuracy: 0.98342020-01-05 17:51:54,222 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:54,245 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:54,248 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 5s 87us/sample - loss: 0.0540 - accuracy: 0.9834 - val_loss: 0.0758 - val_accuracy: 0.9782\n",
"Epoch 4/6\n",
"60000/60000 [==============================] - 4s 64us/step - loss: 0.0387 - acc: 0.9880 - val_loss: 0.0859 - val_acc: 0.9761\n",
"59392/60000 [============================>.] - ETA: 0s - loss: 0.0388 - accuracy: 0.98762020-01-05 17:51:59,298 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:59,320 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:51:59,324 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 5s 84us/sample - loss: 0.0387 - accuracy: 0.9876 - val_loss: 0.0836 - val_accuracy: 0.9777\n",
"Epoch 5/6\n",
"60000/60000 [==============================] - 4s 63us/step - loss: 0.0304 - acc: 0.9904 - val_loss: 0.0875 - val_acc: 0.9766\n",
"59520/60000 [============================>.] - ETA: 0s - loss: 0.0282 - accuracy: 0.99142020-01-05 17:52:04,410 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:52:04,433 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:52:04,436 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 5s 85us/sample - loss: 0.0280 - accuracy: 0.9915 - val_loss: 0.0754 - val_accuracy: 0.9811\n",
"Epoch 6/6\n",
"60000/60000 [==============================] - 4s 64us/step - loss: 0.0220 - acc: 0.9933 - val_loss: 0.0847 - val_acc: 0.9793\n",
"Test score: 0.08471047468512916\n",
"Test accuracy: 0.9793\n"
"59520/60000 [============================>.] - ETA: 0s - loss: 0.0242 - accuracy: 0.99252020-01-05 17:52:09,482 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:52:09,504 - trains - WARNING - too many indices for array\n",
"2020-01-05 17:52:09,507 - trains - WARNING - too many indices for array\n",
"60000/60000 [==============================] - 5s 85us/sample - loss: 0.0243 - accuracy: 0.9924 - val_loss: 0.0769 - val_accuracy: 0.9824\n",
"Test score: 0.07691321085649504\n",
"Test accuracy: 0.9824\n"
]
}
],
@ -237,8 +313,8 @@
"print(X_test.shape[0], 'test samples')\n",
"\n",
"# convert class vectors to binary class matrices\n",
"Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
"Y_test = np_utils.to_categorical(y_test, nb_classes)\n",
"Y_train = to_categorical(y_train, nb_classes)\n",
"Y_test = to_categorical(y_test, nb_classes)\n",
"\n",
"hidden_dim = task_params['hidden_dim']\n",
"model = Sequential()\n",
@ -266,15 +342,22 @@
" verbose=1, validation_data=(X_test, Y_test))\n",
"score = model.evaluate(X_test, Y_test, verbose=0)\n",
"print('Test score:', score[0])\n",
"print('Test accuracy:', score[1])\n"
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "conda_python3",
"language": "python",
"name": "python3"
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
@ -286,16 +369,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
"version": "3.6.9"
}
},
"nbformat": 4,