mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
300 lines
83 KiB
Plaintext
300 lines
83 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# TRAINS - Example of integrating plots and training in a Jupyter notebook.\n",
|
|
"# In this example, simple graphs are shown, then an MNIST classifier is trained using Keras. \n",
|
|
"!pip install -U pip\n",
|
|
"!pip install -U trains\n",
|
|
"!pip install -U numpy==1.18.0\n",
|
|
"!pip install -U tensorflow==2.0.0\n",
|
|
"!pip install -U tensorboard==2.0.0\n",
|
|
"!pip install -U matplotlib==3.1.2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint\n",
|
|
"from tensorflow.keras.datasets import mnist\n",
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
"from tensorflow.keras.layers import Dense, Dropout, Activation\n",
|
|
"from tensorflow.keras.optimizers import SGD, Adam, RMSprop\n",
|
|
"from tensorflow.keras.utils import to_categorical\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TRAINS Task: created new task id=83ebee675e9f4b50af88da70be6a30d6\n",
|
|
"2020-01-05 17:51:21,909 - trains.Task - INFO - No repository found, storing script code instead\n",
|
|
"TRAINS results page: https://demoapp.trains.allegro.ai/projects/087f765c846c4c76a7e9f3d035667d82/experiments/83ebee675e9f4b50af88da70be6a30d6/output/log\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Connect to TRAINS (Task.init creates a new task for this notebook run)\n",
|
|
"from trains import Task\n",
|
|
"task = Task.init(project_name = 'examples', task_name = 'notebook example')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TRAINS Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Set script parameters\n",
|
|
"task_params = {'num_scatter_samples': 60, 'sin_max_value': 20, 'sin_steps': 30}\n",
|
|
"task_params = task.connect(task_params)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEICAYAAABbOlNNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHetJREFUeJzt3X2QHHd95/H3Z63Y3PAo2SohW9KMDeb5EoMWHwQCJPgJirNMQoy4pZB5uC2T+O5ykLuI2zri89UmBuoCRY4jbMCxyU5hEXOcdRyU8QMuuFzkeE1kYxtsyWJ3LSHbG2QbzBJjS9/7o3vF9Ghmd0bz0PPweVV1bfevfz39VU/PfNX9+03/FBGYmZktGck7ADMz6y1ODGZmluHEYGZmGU4MZmaW4cRgZmYZTgxmZpbhxGBDRdKYpG/mHYdZL3NisIEj6fWS/p+kxyUdkvS3kl4NEBHliDivy/FcLukpST9Np/sl/XdJ65t4jZD0wk7GabbEicEGiqTnAF8D/hxYA5wG/BfgyTzjAnZExLNJYno78HzgjmaSg1m3ODHYoHkRQER8KSIOR8TPI+KbEXEXgKRLJP3fpcrp/8QvlbRH0mOSPiNJ6brLJU1X1C2l9VdVvNa+9Crgh5LGVgouIp6KiHuAdwILwIcrXv9fS9qbXuXslHRqWv7ttMqdkp6Q9E5Jp0j6WhrzIUnfkeTPs7WFTyQbNPcDhyVdI+ktklY3sM3bgFcDvwpcDJy/0gaSngl8GnhLeiXw68DuRoOMiMPA9cBvpK/3W8CfpvtfD8wB16Z135Bu9msR8ayI2EGSUPYDa4F1wH8C/HwbawsnBhsoEfET4PUkX5J/CSyk//tet8xmV0bEYxExD3wLOKvB3R0BXiHpn0XEwfRKoBk/Irm1BDAGXBUR342IJ4GPAK+VVKqz7VMkCaSYXoV8J/zgM2sTJwYbOBHx/Yi4JCI2AK8ATgU+tcwmD1XMLwLPamAfPyO5HXQpcFDS/5H0kiZDPQ04lM6fSnKVsPT6TwA/TuvU8glgL/DN9HbW9ib3bVaXE4MNtIj4AXA1SYJo1s+AQsXy86te+4aIOJfkf+4/ILlCaUjaHvAvge+kRT8CihXrnwmcDByotX1E/DQiPhwRZwAXAh+S9OZG92+2HCcGGyiSXiLpw5I2pMsbgXcBu47j5XYDb5C0SdJzSW7vLO1nnaQt6Rf4k8ATJLeWVopvlaSXAl8iSTR/lq76EvBeSWdJOgn4E+C2iJhN1z8MnFHxOm+T9MK0ofxx4HAj+zdrhBODDZqfAv8CuE3Sz0gSwt1U9P5pVETcCOwA7gLuIOkGu2QE+BDJ//QPAW8EPrjMy71T0hMkX+I7SW4TbY6IH6X7ugn4z8BXgIPAC4CtFdtfDlyT9kK6GDgTuIkkIf0d8D8i4lvN/hvNapHbq8zMrJKvGMzMLMOJwczMMpwYzMwsw4nBzMwyVuUdwPE45ZRTolQq5R2GmVlfueOOO/4xItauVK8vE0OpVGJmZibvMMzM+oqkuZVr+VaSmZlVcWIwM7MMJwYzM8toS2KQdJWkRyTdXWe9JH06HYTkLkmvqli3LR0kZY+kbe2Ix8zMjl+7rhiuBi5YZv1bSJ7tciYwDnwWQNIa4I9Jnm1zNvDHDQ6sYmZmHdKWxBAR3+aXz5WvZQvwxUjsAp6XjnV7PnBjRByKiEeBG1k+wfSUcrlMqVRiZGSEUqlEuVzOOyQzs5Z1q7vqacCDFcv707J65ceQNE5ytcGmTZs6E2UTyuUy4+PjLC4uAjA3N8f4+DgAY2MrDv1rZtaz+qbxOSKmImI0IkbXrl3x9xkdNzExcTQpLFlcXGRiYiKniIaHr9TMOqtbieEAsLFieUNaVq+8583PzzdVbu2xdKU2NzdHRBy9UnNyMGufbiWGncB70t5JrwEej4iDwA3AeZJWp43O56VlPa/e7axeuM01yHyllq9Gr9Z8Vd
fnIqLliWRYwoPAUyTtBO8nGST90nS9gM8ADwDfA0Yrtn0fyaDme4H3NrK/zZs3R96mp6ejUCgEcHQqFAoxPT2dd2gDTVLmmC9NkvIObeA1es77s9G7gJlo5Du9kUq9NvVCYohIPgDFYjEkRbFY9InfBcVisWZiKBaLeYc28Bo99n6PelejiaFvGp970djYGLOzsxw5coTZ2Vn3RuqCyclJCoVCpqxQKDA5OVmzvm9ptE+j7Wpuf2u/rp/HjWSPXpt65YrB8tHolZpvabSXrxjy0c7zGN9KsmHnL6j2chtDPtp5HjeaGHwryQaWb2m019jYGFNTUxSLRSRRLBaZmpo65hZqo/WsMXmcx04MXeJ73d3nLsXt12i7mtvf2ieP89iJoQv8o6x8NNtQbdaLcjmPG7nf1GtTv7Ux+F53ftyl2AZBu85jGmxjUFK3v4yOjkY/jfk8MjJCreMsiSNHjuQQkZkNI0l3RMToSvV8K6kLfK/bzPqJE0MX+F63mfUTJ4YucPc9M+snbmMwMxsSbmMwM7Pj4sRgZmYZTgxmZpbhxGBmZhltSQySLpB0n6S9krbXWP9JSbvT6X5Jj1WsO1yxbmc74jEzs+O3qtUXkHQCybCd55IM63m7pJ0Rce9SnYj49xX1/w3wyoqX+HlEnNVqHGZm1h7tuGI4G9gbEfsi4hfAtcCWZeq/i2SMaDMz60HtSAynAQ9WLO9Py44hqQicDtxSUfwMSTOSdkm6qN5OJI2n9WYWFhbaELaZmdXS7cbnrcB1EXG4oqyY/uDiXwGfkvSCWhtGxFREjEbE6Nq1a7sRq5nZUGpHYjgAbKxY3pCW1bKVqttIEXEg/bsPuJVs+4OZmXVZOxLD7cCZkk6XdCLJl/8xvYskvQRYDfxdRdlqSSel86cArwPurd7WzMy6p+VeSRHxtKTLgBuAE4CrIuIeSVeQDAqxlCS2AtdG9uFMLwU+J+kISZK6srI3k5mZdZ8fomdmNiT8ED0zMzsuTgxmZpbhxGBmZhlODDWUy2VKpRIjIyOUSiXK5XLeIZmZdU3LvZIGTblcZnx8nMXFRQDm5uYYHx8H8FCcZjYUfMVQZWJi4mhSWLK4uMjExEROEZmZdZcTQ5X5+fmmys3MBo0TQ5VNmzY1VW5mNmicGKpMTk5SKBQyZYVCgcnJyZwiMjPrLieGKmNjY0xNTVEsFpFEsVhkamrKDc9d4N5gZr3Bj8SwnlDdGwySKzUnZbP28SMxrK+4N5hZ73BisJ7g3mBmvcOJwXqCe4OZ9Q4nBusJ7g1m1jucGKwnuDdYfvLuDZb3/q2GiGh5Ai4A7gP2AttrrL8EWAB2p9MHKtZtA/ak07ZG9rd58+Yws9ZNT09HoVAI4OhUKBRienp6KPY/bEhG1VzxO7bl7qqSTgDuB84F9pOMAf2uqBiiU9IlwGhEXFa17RpgBhhNT4o7gM0R8ehy+3R3VbP2KJVKzM3NHVNeLBaZnZ0d+P0Pm252Vz0b2BsR+yLiF8C1wJYGtz0fuDEiDqXJ4EaSqw8z64K8e4PlvX+rrR2J4TTgwYrl/WlZtd+RdJek6yRtbHJbJI1LmpE0s7Cw0IawzSzv3mB5799q61bj8/8GShHxqyRXBdc0+wIRMRURoxExunbt2rYHaDaM8u4Nlvf+rbZ2JIYDwMaK5Q1p2VER8eOIeDJd/DywudFtzaxz8u4Nlvf+rY5GWqiXm0hGgdsHnA6cCNwJvLyqzvqK+bcDu9L5NcAPgdXp9ENgzUr7dK8ka7fp6ekoFoshKYrFonvFWFd0+7yjwV5JLQ/tGRFPS7oMuAE4AbgqIu6RdEUaxE7g30q6EHgaOETSfZWIOCTpv5L0ZAK4IiIOtRqTWTM8nKvloZfPOz9d1Yaeu0xaHvI47/x0VbMGucuk5aGXzzsnBht67jJpeejl886JwYaeu0xaHnr5vHNisKHnLpOWh14+79z43GPK5TITExPMz8+zadMmJicne+JEMbP+12jjc8vdVa
19ern7mpkND99K6iEe99jMeoETQw/p5e5rZjY8nBh6SC93XzOz4eHE0EN6ufuamQ0PJ4Ye0svd18xseLi7qpnZkPCzkszM7Lg4MZiZWYYTg5mZZTgxmJlZRlsSg6QLJN0naa+k7TXWf0jSvZLuknSzpGLFusOSdqfTznbEY2Zmx6/lZyVJOgH4DHAusB+4XdLOiLi3oto/AKMRsSjpg8DHgXem634eEWe1GoeZmbVHO64Yzgb2RsS+iPgFcC2wpbJCRHwrIpYeArQL2NCG/ZqZWQe0IzGcBjxYsbw/Lavn/cA3KpafIWlG0i5JF9XbSNJ4Wm9mYWGhtYjNzKyurj52W9K7gVHgjRXFxYg4IOkM4BZJ34uIB6q3jYgpYAqSH7h1JWAzsyHUjiuGA8DGiuUNaVmGpHOACeDCiHhyqTwiDqR/9wG3Aq9sQ0xmZnac2pEYbgfOlHS6pBOBrUCmd5GkVwKfI0kKj1SUr5Z0Ujp/CvA6oLLR2szMuqzlW0kR8bSky4AbgBOAqyLiHklXADMRsRP4BPAs4G8kAcxHxIXAS4HPSTpCkqSurOrNZGZmXeaH6JmZDQk/RM/MzI6LE4OZmWU4MZiZWYYTg5mZZTgxmJlZhhODmZllODGYmVmGE4OZmWU4MZiZWYYTg5mZZTgxmJlZhhODmZllODGYmVmGE4OZmWU4MZiZWYYTg5mZZbQlMUi6QNJ9kvZK2l5j/UmSdqTrb5NUqlj3kbT8PknntyOeWsrlMqVSiZGREUqlEuVyuVO7MjPray0P7SnpBOAzwLnAfuB2STurhuh8P/BoRLxQ0lbgY8A7Jb2MZIzolwOnAjdJelFEHG41rkrlcpnx8XEWFxcBmJubY3x8HICxsbF27srMrO+144rhbGBvROyLiF8A1wJbqupsAa5J568D3qxk8OctwLUR8WRE/BDYm75eW01MTBxNCksWFxeZmJho967MzPpeOxLDacCDFcv707KadSLiaeBx4OQGtwVA0rikGUkzCwsLTQU4Pz/fVLmZ2TDrm8bniJiKiNGIGF27dm1T227atKmpcjOzYdaOxHAA2FixvCEtq1lH0irgucCPG9y2ZZOTkxQKhUxZoVBgcnKy3bsyM+t77UgMtwNnSjpd0okkjck7q+rsBLal8+8AbomISMu3pr2WTgfOBP6+DTFljI2NMTU1RbFYRBLFYpGpqSk3PJuZ1dByr6SIeFrSZcANwAnAVRFxj6QrgJmI2Al8AfhrSXuBQyTJg7Tel4F7gaeB3293j6QlY2NjTgRmZg1oSxtDRHw9Il4UES+IiMm07KNpUiAi/ikifjciXhgRZ0fEvoptJ9PtXhwR32hHPNZb/BsSs/7S8hWD2XL8GxKz/tM3vZKsP/k3JGb9x4nBOsq/ITHrP04M1lH+DYlZ/3FisI7yb0jM+o8Tg3WUf0Ni1n+cGKzjxsbGmJ2d5ciRI8zOzjopdIm7CdvxcmIwG0BL3YTn5uaIiKPdhPs5OTjRdY+SJ1P0l9HR0ZiZmck7DLOeVSqVmJubO6a8WCwyOzvb/YBaVP17GEjaqnxbsjmS7oiI0RXrOTGYDZ6RkRFqfbYlceTIkRwias2gJbq8NJoYfCvJbAANWjdh/x6mu5wYzAbQoHUTHrRE1+ucGMwG0KB1Ex60RNfr3MZgZn2hXC4zMTHB/Pw8mzZtYnJysm8TXV7cxmDWIe42mY9++T3MIJwffuy2WRP8GHFbzqCcHy3dSpK0BtgBlIBZ4OKIeLSqzlnAZ4HnAIeByYjYka67Gngj8Hha/ZKI2L3Sfn0ryfLibpO2nF4/P7p1K2k7cHNEnAncnC5XWwTeExEvBy4APiXpeRXr/0NEnJVOKyYFszy526QtZ1DOj1YTwxbgmnT+GuCi6goRcX9E7EnnfwQ8Aqxtcb9muXC3SVvOoJwfrSaGdRFxMJ1/CFi3XGVJZwMnAg9UFE9KukvSJyWdtMy245JmJM0sLCy0GLbZ8XG3SVvOwJwfEbHsBN
wE3F1j2gI8VlX30WVeZz1wH/CaqjIBJ5FccXx0pXgigs2bN4dZXqanp6NYLIakKBaLMT09nXdI1kN6+fwAZqKB79hWG5/vA94UEQclrQdujYgX16j3HOBW4E8i4ro6r/Um4A8j4m0r7deNzwn36zazZnSr8XknsC2d3wZcXyOQE4GvAl+sTgppMkGSSNon7m4xnqExiI9VNrPe0OoVw8nAl4FNwBxJd9VDkkaBSyPiA5LeDfwVcE/FppdExG5Jt5A0RAvYnW7zxEr79RVD73eLM7Pe48duD7hBe6yymXWeH4kx4AalW5yZ9R4nhj41MN3izKznODH0qUF7rLKZ9Q63MZiZDQm3MZiZ2XFxYjAzswwnBjMzy3BiMDOzDCcGMzPLcGIwM7MMJwYzM8twYjAzswwnBjMzy3BiMDOzDCcGMzPLcGIwM7OMlhKDpDWSbpS0J/27uk69w5J2p9POivLTJd0maa+kHekwoGZmlqNWrxi2AzdHxJnAzelyLT+PiLPS6cKK8o8Bn4yIFwKPAu9vMR4zM2tRq4lhC3BNOn8NcFGjG0oS8FvAdcezvZmZdUariWFdRBxM5x8C1tWp9wxJM5J2SVr68j8ZeCwink6X9wOn1duRpPH0NWYWFhZaDNvMzOpZtVIFSTcBz6+xaqJyISJCUr1Rf4oRcUDSGcAtkr4HPN5MoBExBUxBMlBPM9uamVnjVkwMEXFOvXWSHpa0PiIOSloPPFLnNQ6kf/dJuhV4JfAV4HmSVqVXDRuAA8fxbzAzszZq9VbSTmBbOr8NuL66gqTVkk5K508BXgfcG8mYot8C3rHc9mZm1l2tJoYrgXMl7QHOSZeRNCrp82mdlwIzku4kSQRXRsS96bo/Aj4kaS9Jm8MXWozHzMxapOQ/7v1ldHQ0ZmZm8g7DzKyvSLojIkZXqudfPpuZWYYTg5mZZTgxmJlZhhODmZllODGYmVmGE4OZmWU4MZiZWYYTg5mZZTgxmJlZhhODmZllODHYcSmXy5RKJUZGRiiVSpTL5bxDMrM2WfGx22bVyuUy4+PjLC4uAjA3N8f4+DgAY2NjeYZmZm3gKwZr2sTExNGksGRxcZGJiYk6W5hZP3FisKbNz883VW5m/cWJwZq2adOmpsrNrL84MVjTJicnKRQKmbJCocDk5GROEZlZO7WUGCStkXSjpD3p39U16vympN0V0z9Juihdd7WkH1asO6uVeKw7xsbGmJqaolgsIoliscjU1JQbnrvAvcEa4+PUoog47gn4OLA9nd8OfGyF+muAQ0AhXb4aeEez+928eXOYDZvp6ekoFAoBHJ0KhUJMT0/nHVpP8XGqD5iJBr5jWxraU9J9wJsi4qCk9cCtEfHiZeqPA2+MiLF0+WrgaxFxXTP79dCeNoxKpRJzc3PHlBeLRWZnZ7sfUI/ycaqvW0N7rouIg+n8Q8C6FepvBb5UVTYp6S5Jn5R0Ur0NJY1LmpE0s7Cw0ELIZv3JvcEa4+PUuhUTg6SbJN1dY9pSWS+9TKl7+ZFeUfxz4IaK4o8ALwFeTXKb6Y/qbR8RUxExGhGja9euXSlss4Hj3mCN8XFq3YqJISLOiYhX1JiuBx5Ov/CXvvgfWealLga+GhFPVbz2wfTW15PAXwFnt/bPMesd7W4AdW+wxnTqOA1Vg3YjDRH1JuATZBufP75M3V3Ab1aVrU//CvgUcGUj+3Xjs/W6TjWATk9PR7FYDElRLBbdoFpHu4/ToDRo06XG55OBLwObgDng4og4JGkUuDQiPpDWKwF/C2yMiCMV298CrE0Tw+50mydW2q8bn63XuQF0sAzK+9lo43NLiSEvTgzW60ZGRqj12ZLEkSNHamxhvWxQ3s9u9UqyPjBU90Z7hBtAB8uwvZ9ODANu6RHZc3NzRMTRR2Q7OXSWG4oHy9C9n400RPTa5MbnxhWLxUyD2dJULBbzDm3guaF4sAzC+0k3Gp/z4jaGxg3KvVEza53bGAwYvnujZtY6J4YBN3T3Rs
2sZU4MA86PyDazZrmNwcxsSLiNwczMjosTg5mZZTgxmJlZhhODmZllODGYmVmGE4OZmWU4MZiZWYYTg2X4Ed1m1lJikPS7ku6RdCQdta1evQsk3Sdpr6TtFeWnS7otLd8h6cRW4rHW+BHd+XFCzoePex2NPIK13gS8FHgxcCswWqfOCcADwBnAicCdwMvSdV8GtqbzfwF8sJH9+rHbneFHdOdjUMYT7jfDeNzp5mO3Jd0K/GFEHPOcCkmvBS6PiPPT5Y+kq64EFoDnR8TT1fWW40didIYf0Z2PQRlPuN8M43HvpUdinAY8WLG8Py07GXgsIp6uKq9J0rikGUkzCwsLHQt2mPkR3fmYn59vqtzaw8e9vhUTg6SbJN1dY9rSjQCXRMRURIxGxOjatWu7ueuh4Ud058MJOR8+7vWtmBgi4pyIeEWN6foG93EA2FixvCEt+zHwPEmrqsotJ35Edz6ckPPh476MRhoiVppYvvF5FbAPOJ1fNj6/PF33N2Qbn3+vkf258dkGzSCMJ9yPhu24043GZ0lvB/4cWAs8BuyOiPMlnQp8PiLemtZ7K/Apkh5KV0XEZFp+BnAtsAb4B+DdEfHkSvt147OZWfMabXz2QD1mZkOil3olmZlZH3FiMDOzDCcGMzPLcGIwM7OMvmx8lrQAHPtb9sacAvxjG8NpF8fVHMfVHMfVnEGNqxgRK/5CuC8TQyskzTTSKt9tjqs5jqs5jqs5wx6XbyWZmVmGE4OZmWUMY2KYyjuAOhxXcxxXcxxXc4Y6rqFrYzAzs+UN4xWDmZktw4nBzMwyBjYxSLpA0n2S9kraXmP9SZJ2pOtvk1TqQkwbJX1L0r2S7pH072rUeZOkxyXtTqePdjqudL+zkr6X7rPWEK2S9On0eN0l6VVdiOnFFcdht6SfSPqDqjpdOV6SrpL0iKS7K8rWSLpR0p707+o6225L6+yRtK0LcX1C0g/S9+mrkp5XZ9tl3/MOxHW5pAMV79Vb62y77Ge3A3HtqIhpVtLuOtt28njV/G7I7Rxr5Nnc/TaRPN77AeAMfjkGxMuq6vwe8Bfp/FZgRxfiWg+8Kp1/NnB/jbjeBHwth2M2C5yyzPq3At8ABLwGuC2H9/Qhkh/odP14AW8AXgXcXVH2cWB7Or8d+FiN7daQjEeyBlidzq/ucFznAavS+Y/ViquR97wDcV1OMjb8Su/zsp/ddsdVtf6/AR/N4XjV/G7I6xwb1CuGs4G9EbEvIn5BMuZD9VCkW4Br0vnrgDdLUieDioiDEfHddP6nwPdZZpzrHrMF+GIkdpGMvre+i/t/M/BARBzvL95bEhHfBg5VFVeeQ9cAF9XY9Hzgxog4FBGPAjcCF3Qyroj4ZvxyLPVdJKMjdlWd49WIRj67HYkr/fxfDHypXftr1DLfDbmcY4OaGE4DHqxY3s+xX8BH66QfoseBk7sSHZDeunolcFuN1a+VdKekb0h6eZdCCuCbku6QNF5jfSPHtJO2Uv8Dm8fxAlgXEQfT+YeAdTXq5H3c3kdypVfLSu95J1yW3uK6qs5tkTyP128AD0fEnjrru3K8qr4bcjnHBjUx9DRJzwK+AvxBRPykavV3SW6X/BrJ6Hj/q0thvT4iXgW8Bfh9SW/o0n5XJOlE4EKSoWCr5XW8MiK5pu+pvt+SJoCngXKdKt1+zz8LvAA4CzhIctuml7yL5a8WOn68lvtu6OY5NqiJ4QCwsWJ5Q1pWs46kVcBzgR93OjBJv0Lyxpcj4n9Wr4+In0TEE+n814FfkXRKp+OKiAPp30eAr5Jc0ldq5Jh2yluA70bEw9Ur8jpeqYeXbqelfx+pUSeX4ybpEuBtwFj6hXKMBt7ztoqIhyPicEQcAf6yzv7yOl6rgN8GdtSr0+njVee7IZdzbFATw+3AmZJOT/+3uRXYWVVnJ7DUev8O4JZ6H6B2Se9hfgH4fkT8WZ06z19q65B0Nsl71NGEJemZkp69NE/SeHl3VbWdwHuUeA3weMUlbq
fV/Z9cHserQuU5tA24vkadG4DzJK1Ob52cl5Z1jKQLgP8IXBgRi3XqNPKetzuuyjapt9fZXyOf3U44B/hBROyvtbLTx2uZ74Z8zrFOtLD3wkTSi+Z+kh4OE2nZFSQfFoBnkNya2Av8PXBGF2J6Pcml4F3A7nR6K3ApcGla5zLgHpLeGLuAX+9CXGek+7sz3ffS8aqMS8Bn0uP5PWC0S+/jM0m+6J9bUdb140WSmA4CT5Hcw30/SZvUzcAe4CZgTVp3FPh8xbbvS8+zvcB7uxDXXpJ7zkvn2FLvu1OBry/3nnc4rr9Oz527SL7w1lfHlS4f89ntZFxp+dVL51RF3W4er3rfDbmcY34khpmZZQzqrSQzMztOTgxmZpbhxGBmZhlODGZmluHEYGZmGU4MZmaW4cRgZmYZ/x8nzHVmzMAuywAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEFVJREFUeJzt3X+sX3V9x/Hny3opA2qgoLWUYoFVXbNoYXcFHHMMVBBdkMQ5cDOQOOsvsrFgNsbixE0WNRPG4sK8jIaq/FQgEGVT1pgxf6xycbUUC7SQEqn9AUJp0Vnb8tof5zR+e3N/fPv9cb69/bweyc33/Pqe874n93U/55zP95yvbBMR5XnZoAuIiMFI+CMKlfBHFCrhjyhUwh9RqIQ/olAJf6Ek3STpU4OuIwYn4Y/GSFogyZJefiCurzQJf0ShEv4BkPRXkjZK2iHpMUln19OXSPqepG2SNkn6vKRDWt5nSR+RtK5+799LOknSdyVtl3TH3uUlnSnpaUlXSnpW0gZJfzxJTe+UtKre9nclvWGSZd8k6UFJL9Svb2qZt0HSW1rGr5L05Xr0gfp1m6QXJZ0u6RJJ36l/1xckPbp3f3Syvqn2ffxKwt8wSa8DLgV+2/Ys4BxgQz17D/AXwDHA6cDZwEfGrOIc4LeA04C/BEaAPwHmA78JXNSy7Kvrdc0DLgZG6u2PrelkYBnwQeBo4AvAvZJmjrPsbODrwD/Xy14DfF3S0W38+m+uX4+0fYTt79XjpwJP1LV+Arir3k6n64s2JPzN2wPMBBZJGrK9wfYTALYfsv0/tnfb3kAVwt8b8/7P2t5u+xFgDfBN20/afgH4d+DkMct/3PZO2/9FFdr3jFPTUuALtlfa3mN7ObCT6h/MWO8A1tn+Ul3nrcCjwB90sC/22gr8k+1dtm8HHqu3E32U8DfM9nrgMuAqYKuk2yQdCyDptZK+JmmzpO3AP1C1hq22tAz/3zjjR7SMP2/7Zy3jTwHHjlPWa4DL60P+bZK2UR1JjLfssfV6Wj1FdXTRqY3e9w6zieqMHkr4B8D2LbbPoAqdgc/Us66nakUX2n4FcCWgLjZ1lKTDW8aPB34yznI/Bq62fWTLz2F1qz7WT+q6Wx0PbKyHfwYc1jLv1S3DE91COk9S6+/ZWmcn64s2JPwNk/Q6SWfV59O/oGqtX6pnzwK2Ay9Kej3w4R5s8pOSDpH0u8A7ga+Ms8wNwIcknarK4ZLeIWnWOMveB7xW0nslvVzSHwGLgK/V81cBF0oakjQMvLvlvc9Q/a4njlnnq4A/q9/zh8Bv1NvpdH3RhoS/eTOBTwPPApup/vD/up73MeC9wA6qQN7e5bY2A89TtaI3Ax+y/ejYhWyPAh8APl8vvx64ZLwV2v4p1T+Ry4GfUl10fKftZ+tFPg6cVK/nk8AtLe/9OXA18J369GLvNYWVwEKqfXI18O56O52uL9qgPMzj4CTpTODLto8bdC2TkXQJ8Kf1aVA0KC1/RKES/ohC5bA/olBp+SMK1ejdUMfMnuEF84fGnff46sPGnR4R7fsFP+OX3tnWZ0O6Cr+kc4HrgBnAv9n+9GTLL5g/xPe/MX/ceeccu7ibUiICWOkVbS/b8WG/pBnAvwBvp/qQx0WSFnW6vohoVjfn/EuA9fVNJb8EbgPO701ZEdFv3YR/HtVnwvd6mnFu7pC0VNKopNFnfrqni81FRC/1/Wq/7RHbw7aHX3n0jH5vLiLa1E34N1Ld9rnXcfzqzq6IOMB1c7X/QWChpBOoQn8h1U0pE3p89WETXtX/xk9WTfi+9ARE9F7H4be9W9KlwDeouvqW1U+XiYhpoKt+ftv38av7riNiGsnHeyMKlfBHFCrhjyhUwh9RqAPmO84m686bqBswXYARnUvLH1GohD+iUAl/RKES/ohCJfwRhTpgrvZPJjcDRfReWv6IQiX8EYVK+CMKlfBHFCrhjyhUwh9RqGnR1T
eRTm4Gmup9EaVIyx9RqIQ/olAJf0ShEv6IQiX8EYVK+CMKNa27+iaTbsCIyXUVfkkbgB3AHmC37eFeFBUR/deLlv/3bT/bg/VERINyzh9RqG7Db+Cbkh6StHS8BSQtlTQqaXQXO7vcXET0SreH/WfY3ijpVcD9kh61/UDrArZHgBGAV2i2u9xeRPRIVy2/7Y3161bgbmBJL4qKiP7ruOWXdDjwMts76uG3AX/Xs8r6KN2AEd0d9s8B7pa0dz232P6PnlQVEX3XcfhtPwm8sYe1RESD0tUXUaiEP6JQCX9EoRL+iEIdtHf1daqTbsB0AcZ0lJY/olAJf0ShEv6IQiX8EYVK+CMKlav9+2Giq/q5GSimo7T8EYVK+CMKlfBHFCrhjyhUwh9RqIQ/olDp6uuBPBMwpqO0/BGFSvgjCpXwRxQq4Y8oVMIfUaiEP6JQ6errs3QDxoFqypZf0jJJWyWtaZk2W9L9ktbVr0f1t8yI6LV2DvtvAs4dM+0KYIXthcCKejwippEpw2/7AeC5MZPPB5bXw8uBd/W4rojos07P+efY3lQPb6b6xt5xSVoKLAU4lMM63FxE9FrXV/ttG/Ak80dsD9seHmJmt5uLiB7pNPxbJM0FqF+39q6kiGhCp4f99wIXA5+uX+/pWUUFSTdgDFI7XX23At8DXifpaUnvpwr9WyWtA95Sj0fENDJly2/7oglmnd3jWiKiQfl4b0ShEv6IQiX8EYVK+CMKlbv6DlDpBox+S8sfUaiEP6JQCX9EoRL+iEIl/BGFSvgjCpWuvmmok27AdAHGWGn5IwqV8EcUKuGPKFTCH1GohD+iULnaf5CZ6Kp+bgaKsdLyRxQq4Y8oVMIfUaiEP6JQCX9EoRL+iEKlq68QeSZgjNXO13Utk7RV0pqWaVdJ2ihpVf1zXn/LjIhea+ew/ybg3HGmX2t7cf1zX2/Lioh+mzL8th8AnmuglohoUDcX/C6VtLo+LThqooUkLZU0Kml0Fzu72FxE9FKn4b8eOAlYDGwCPjfRgrZHbA/bHh5iZoebi4he6yj8trfY3mP7JeAGYElvy4qIfuuoq0/SXNub6tELgDWTLR8HtnQDlmnK8Eu6FTgTOEbS08AngDMlLQYMbAA+2McaI6IPpgy/7YvGmXxjH2qJiAbl470RhUr4IwqV8EcUKuGPKFTu6otJpRvw4JWWP6JQCX9EoRL+iEIl/BGFSvgjCpXwRxQqXX3RsU66AdMFeOBIyx9RqIQ/olAJf0ShEv6IQiX8EYXK1f7oi4mu6udmoANHWv6IQiX8EYVK+CMKlfBHFCrhjyhUwh9RqHa+sWc+8EVgDtU39IzYvk7SbOB2YAHVt/a8x/bz/Ss1DgZ5JuCBo52Wfzdwue1FwGnARyUtAq4AVtheCKyoxyNimpgy/LY32f5BPbwDWAvMA84HlteLLQfe1a8iI6L39uucX9IC4GRgJTCn5Zt6N1OdFkTENNF2+CUdAdwJXGZ7e+s826a6HjDe+5ZKGpU0uoudXRUbEb3TVvglDVEF/2bbd9WTt0iaW8+fC2wd7722R2wP2x4eYmYvao6IHpgy/JJE9ZXca21f0zLrXuDievhi4J7elxcR/aLqiH2SBaQzgP8GHgZeqidfSXXefwdwPPAUVVffc5Ot6xWa7VN1drc1R4HSDdielV7Bdj+ndpadsp/f9reBiVaWJEdMU/mEX0ShEv6IQiX8EYVK+CMKlfBHFCoP8IxpIXcD9l5a/ohCJfwRhUr4IwqV8EcUKuGPKFTCH1GodPXFtNdJN2C6ANPyRxQr4Y8oVMIfUaiEP6JQCX9EoXK1Pw5qE13Vz81AafkjipXwRxQq4Y8oVMIfUaiEP6JQCX9Eoabs6pM0H/gi1VdwGxixfZ2kq4APAM/Ui15p+75+FRrRS3kmYHv9/LuBy23/QNIs4CFJ99fzrrX9j/0rLyL6pZ3v6tsEbKqHd0haC8zrd2ER0V/7dc4vaQFwMtU39AJcKmm1pGWSju
pxbRHRR22HX9IRwJ3AZba3A9cDJwGLqY4MPjfB+5ZKGpU0uoudPSg5InqhrfBLGqIK/s227wKwvcX2HtsvATcAS8Z7r+0R28O2h4eY2au6I6JLU4ZfkoAbgbW2r2mZPrdlsQuANb0vLyL6pZ2r/b8DvA94WNLePpArgYskLabq/tsAfLAvFUY0rJRuwHau9n8b0Diz0qcfMY3lE34RhUr4IwqV8EcUKuGPKFTCH1GoPMAzYj8cTN2AafkjCpXwRxQq4Y8oVMIfUaiEP6JQCX9EodLVF9Ej060bMC1/RKES/ohCJfwRhUr4IwqV8EcUKuGPKFS6+iIa0Ek3YL+7ANPyRxQq4Y8oVMIfUaiEP6JQCX9Eoaa82i/pUOABYGa9/Fdtf0LSCcBtwNHAQ8D7bP+yn8VGHIwmuqrf75uB2mn5dwJn2X4j1ddxnyvpNOAzwLW2fx14Hnh/19VERGOmDL8rL9ajQ/WPgbOAr9bTlwPv6kuFEdEXbZ3zS5pRf0PvVuB+4Algm+3d9SJPA/P6U2JE9ENb4be9x/Zi4DhgCfD6djcgaamkUUmju9jZYZkR0Wv7dbXf9jbgW8DpwJGS9l4wPA7YOMF7RmwP2x4eYmZXxUZE70wZfkmvlHRkPfxrwFuBtVT/BN5dL3YxcE+/ioyI3mvnxp65wHJJM6j+Wdxh+2uSfgTcJulTwP8CN/axzojidHIz0JJzft72+qcMv+3VwMnjTH+S6vw/IqahfMIvolAJf0ShEv6IQiX8EYVK+CMKJdvNbUx6BniqHj0GeLaxjU8sdewrdexrutXxGtuvbGeFjYZ/nw1Lo7aHB7Lx1JE6UkcO+yNKlfBHFGqQ4R8Z4LZbpY59pY59HbR1DOycPyIGK4f9EYVK+CMKNZDwSzpX0mOS1ku6YhA11HVskPSwpFWSRhvc7jJJWyWtaZk2W9L9ktbVr0cNqI6rJG2s98kqSec1UMd8Sd+S9CNJj0j683p6o/tkkjoa3SeSDpX0fUk/rOv4ZD39BEkr69zcLumQrjZku9EfYAbVMwBPBA4BfggsarqOupYNwDED2O6bgVOANS3TPgtcUQ9fAXxmQHVcBXys4f0xFzilHp4FPA4sanqfTFJHo/sEEHBEPTwErAROA+4ALqyn/yvw4W62M4iWfwmw3vaTrp7zfxtw/gDqGBjbDwDPjZl8PtVTkKGhpyFPUEfjbG+y/YN6eAfVk6Lm0fA+maSORrnS9ydmDyL884Aft4wP8sm/Br4p6SFJSwdUw15zbG+qhzcDcwZYy6WSVtenBX0//WglaQHVw2NWMsB9MqYOaHifNPHE7NIv+J1h+xTg7cBHJb150AVB9Z+f6h/TIFwPnET1BS2bgM81tWFJRwB3ApfZ3t46r8l9Mk4dje8Td/HE7HYNIvwbgfkt4xM++bffbG+sX7cCdzPYx5JtkTQXoH7dOogibG+p//BeAm6goX0iaYgqcDfbvque3Pg+Ga+OQe2Tetv7/cTsdg0i/A8CC+srl4cAFwL3Nl2EpMMlzdo7DLwNWDP5u/rqXqqnIMMAn4a8N2y1C2hgn0gS1QNg19q+pmVWo/tkojqa3ieNPTG7qSuYY65mnkd1JfUJ4G8GVMOJVD0NPwQeabIO4Faqw8ddVOdu76f6wtMVwDrgP4HZA6rjS8DDwGqq8M1toI4zqA7pVwOr6p/zmt4nk9TR6D4B3kD1ROzVVP9o/rblb/b7wHrgK8DMbraTj/dGFKr0C34RxUr4IwqV8EcUKuGPKFTCH1GohD+iUAl/RKH+H1kk4Yr1Qy1UAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"%matplotlib inline\n",
|
|
"N = task_params['num_scatter_samples']\n",
|
|
"x = np.random.rand(N)\n",
|
|
"y = np.random.rand(N)\n",
|
|
"colors = np.random.rand(N)\n",
|
|
"area = (50 * np.random.rand(N))**2 # 0 to 50 point radii\n",
|
|
"plt.scatter(x, y, s=area, c=colors, alpha=0.5)\n",
|
|
"plt.title('Nice Circles')\n",
|
|
"plt.show()\n",
|
|
"\n",
|
|
"x = np.linspace(0, task_params['sin_max_value'], task_params['sin_steps'])\n",
|
|
"y = np.sin(x)\n",
|
|
"plt.plot(x, y, 'o', color='black')\n",
|
|
"plt.title('Sinus Dots')\n",
|
|
"plt.show()\n",
|
|
"\n",
|
|
"m = np.eye(32, 32, dtype=np.uint8)\n",
|
|
"plt.imshow(m)\n",
|
|
"plt.title('sample output')\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Note: updates to task_params are tracked and reflected in TRAINS\n",
|
|
"task_params['batch_size'] = 128\n",
|
|
"task_params['nb_classes'] = 10\n",
|
|
"task_params['nb_epoch'] = 6\n",
|
|
"task_params['hidden_dim'] = 512\n",
|
|
"batch_size = task_params['batch_size']\n",
|
|
"nb_classes = task_params['nb_classes']\n",
|
|
"nb_epoch = task_params['nb_epoch']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"60000 train samples\n",
|
|
"10000 test samples\n",
|
|
"Model: \"sequential\"\n",
|
|
"_________________________________________________________________\n",
|
|
"Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
"dense (Dense) (None, 512) 401920 \n",
|
|
"_________________________________________________________________\n",
|
|
"activation (Activation) (None, 512) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"dense_1 (Dense) (None, 512) 262656 \n",
|
|
"_________________________________________________________________\n",
|
|
"activation_1 (Activation) (None, 512) 0 \n",
|
|
"_________________________________________________________________\n",
|
|
"dense_2 (Dense) (None, 10) 5130 \n",
|
|
"_________________________________________________________________\n",
|
|
"activation_2 (Activation) (None, 10) 0 \n",
|
|
"=================================================================\n",
|
|
"Total params: 669,706\n",
|
|
"Trainable params: 669,706\n",
|
|
"Non-trainable params: 0\n",
|
|
"_________________________________________________________________\n",
|
|
"Train on 60000 samples, validate on 10000 samples\n",
|
|
"Epoch 1/6\n",
|
|
"59264/60000 [============================>.] - ETA: 0s - loss: 0.2217 - accuracy: 0.93122020-01-05 17:51:43,884 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:43,906 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:43,909 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 6s 99us/sample - loss: 0.2206 - accuracy: 0.9316 - val_loss: 0.1300 - val_accuracy: 0.9586\n",
|
|
"Epoch 2/6\n",
|
|
"59264/60000 [============================>.] - ETA: 0s - loss: 0.0808 - accuracy: 0.97512020-01-05 17:51:49,017 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:49,039 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:49,043 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 5s 84us/sample - loss: 0.0804 - accuracy: 0.9752 - val_loss: 0.0794 - val_accuracy: 0.9765\n",
|
|
"Epoch 3/6\n",
|
|
"59648/60000 [============================>.] - ETA: 0s - loss: 0.0542 - accuracy: 0.98342020-01-05 17:51:54,222 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:54,245 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:54,248 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 5s 87us/sample - loss: 0.0540 - accuracy: 0.9834 - val_loss: 0.0758 - val_accuracy: 0.9782\n",
|
|
"Epoch 4/6\n",
|
|
"59392/60000 [============================>.] - ETA: 0s - loss: 0.0388 - accuracy: 0.98762020-01-05 17:51:59,298 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:59,320 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:51:59,324 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 5s 84us/sample - loss: 0.0387 - accuracy: 0.9876 - val_loss: 0.0836 - val_accuracy: 0.9777\n",
|
|
"Epoch 5/6\n",
|
|
"59520/60000 [============================>.] - ETA: 0s - loss: 0.0282 - accuracy: 0.99142020-01-05 17:52:04,410 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:52:04,433 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:52:04,436 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 5s 85us/sample - loss: 0.0280 - accuracy: 0.9915 - val_loss: 0.0754 - val_accuracy: 0.9811\n",
|
|
"Epoch 6/6\n",
|
|
"59520/60000 [============================>.] - ETA: 0s - loss: 0.0242 - accuracy: 0.99252020-01-05 17:52:09,482 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:52:09,504 - trains - WARNING - too many indices for array\n",
|
|
"2020-01-05 17:52:09,507 - trains - WARNING - too many indices for array\n",
|
|
"60000/60000 [==============================] - 5s 85us/sample - loss: 0.0243 - accuracy: 0.9924 - val_loss: 0.0769 - val_accuracy: 0.9824\n",
|
|
"Test score: 0.07691321085649504\n",
|
|
"Test accuracy: 0.9824\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Load the MNIST data, already split into train and test sets\n",
|
|
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
|
|
"\n",
|
|
"X_train = X_train.reshape(60000, 784)\n",
|
|
"X_test = X_test.reshape(10000, 784)\n",
|
|
"X_train = X_train.astype('float32')\n",
|
|
"X_test = X_test.astype('float32')\n",
|
|
"X_train /= 255.\n",
|
|
"X_test /= 255.\n",
|
|
"print(X_train.shape[0], 'train samples')\n",
|
|
"print(X_test.shape[0], 'test samples')\n",
|
|
"\n",
|
|
"# convert class vectors to binary class matrices\n",
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|
"Y_test = to_categorical(y_test, nb_classes)\n",
|
|
"\n",
|
|
"hidden_dim = task_params['hidden_dim']\n",
|
|
"model = Sequential()\n",
|
|
"model.add(Dense(hidden_dim, input_shape=(784,)))\n",
|
|
"model.add(Activation('relu'))\n",
|
|
"# model.add(Dropout(0.2))\n",
|
|
"model.add(Dense(hidden_dim))\n",
|
|
"model.add(Activation('relu'))\n",
|
|
"# model.add(Dropout(0.2))\n",
|
|
"model.add(Dense(10))\n",
|
|
"model.add(Activation('softmax'))\n",
|
|
"\n",
|
|
"model.summary()\n",
|
|
"\n",
|
|
"model.compile(loss='categorical_crossentropy',\n",
|
|
" optimizer=RMSprop(),\n",
|
|
" metrics=['accuracy'])\n",
|
|
"\n",
|
|
"board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example')\n",
|
|
"model_store = ModelCheckpoint(filepath='/tmp/weight.{epoch}.hdf5')\n",
|
|
"\n",
|
|
"model.fit(X_train, Y_train,\n",
|
|
" batch_size=batch_size, epochs=nb_epoch,\n",
|
|
" callbacks=[board, model_store],\n",
|
|
" verbose=1, validation_data=(X_test, Y_test))\n",
|
|
"score = model.evaluate(X_test, Y_test, verbose=0)\n",
|
|
"print('Test score:', score[0])\n",
|
|
"print('Test accuracy:', score[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "conda_python3",
|
|
"language": "python",
|
|
"name": "conda_python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|