clearml/examples/frameworks/keras/jupyter.ipynb

312 lines
82 KiB
Plaintext
Raw Normal View History

2020-06-15 19:48:51 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Trains - Example of integrating plots and training on jupyter notebook. \n",
"# In this example, simple graphs are shown, then an MNIST classifier is trained using Keras.\n",
"import os\n",
"import tempfile\n",
"\n",
"from tensorflow.keras import utils as np_utils\n",
"from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras.layers import Activation, Dense\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.optimizers import RMSprop\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TRAINS Task: created new task id=2f9f2f08fa90427aa51e34b839e49fb6\n",
"TRAINS results page: https://demoapp.trains.allegro.ai/projects/0e152d03acf94ae4bb1f3787e293a9f5/experiments/2f9f2f08fa90427aa51e34b839e49fb6/output/log\n"
]
}
],
"source": [
"# Register this notebook with TRAINS as the 'notebook example' task in the\n",
"# 'examples' project; `task` is used below to connect the script parameters.\n",
"from trains import Task\n",
"\n",
"task = Task.init(project_name='examples', task_name='notebook example')\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Script parameters. They are connected to the TRAINS task, and the dict returned\n",
"# by connect() is the one the rest of the notebook reads its values from.\n",
"task_params = task.connect({\n",
"    'num_scatter_samples': 60,  # scatter plot: number of random points\n",
"    'sin_max_value': 20,  # sine plot: upper end of the x range\n",
"    'sin_steps': 30,  # sine plot: number of sampled x values\n",
"})\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simple plots. These plots can also be viewed in the experiment's results page."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9d5Bk13Wn+Z3n02eWr2rv0PAgCIAgQO8kUo6SZjgyO9pgrCRKM5J2dkaandmJiVmFYlehnV3trmJG0soOKYmkKJEUCYmgCNomCQ80utFAA+2rTbmsrKr0mc/e/SOr2pbJqsqqajTyQwDoznx5330v8/3uueeee44opejSpUuXLm98tK3uQJcuXbp06QxdQe/SpUuXW4SuoHfp0qXLLUJX0Lt06dLlFqEr6F26dOlyi9AV9C5dunS5RegKepc3FCJSFZG9m3zO/yAif7rGz35SRP63TvepS5fF6Ap6l5sGERkVkbyIJK567RdE5DsLf1dKJZVSZzfg3D8rIi/MDxgTIvJVEXnn/Dl/Wyn1C50+Z5cunaYr6F1uNnTgX23mCUXk3wD/L/DbwCCwE/gD4KNtfNbY2N516dI+XUHvcrPxfwK/ISLZxd4UESUi++f/HBOR3xWR8yJSEpHvi0hs/r23i8hTIlIUkaMi8t4l2ssAvwX8ilLqi0qpmlLKV0r9vVLq384f85si8lfzf94934efF5ELwLfmX3/nVee7KCIfX+J8PyIiR+aPe0pE7r3qvX8nImMiUhGREyLygbXdwi5vVrrWRZebjReA7wC/AfzHFY79v4C7gEeBSeBhIBKRbcBXgJ8D/hH4APAFEbldKTV9XRuPAA7wd6vs53uAO+bPtwv4KvAJ4PNAGthx/QdE5H7gz4Efnb/Ofw48JiIHgd3ArwIPKaXGRWQ3rdlKly5t07XQu9yM/Cfg10Skf6kDREQD/gfgXymlxpRSoVLqKaWUS0soH1dKPa6UipRSX6cloD+0SFO9QEEpFayyj785b803gJ8FvqGU+uy8dT+jlDqyyGc+AfyRUurZ+f5+CnCBtwMhYAN3ioiplBpVSp1ZZZ+6vMnpCnqXmw6l1CvAPwD/fpnD+mhZ1ouJ3i7gY/NujaKIFIF3AsOLHDsD9K3BF37xqj/vWKIfi/Xr16/r1w5gRCl1GvifgN8E8iLy1yIysso+dXmT0xX0Ljcr/yvwi8C2Jd4vAE1g3yLvXQT+UimVverfhFLqdxY59mlaVvKPr7J/V6cpvbhEPxbr1/9+Xb/iSqnPAiilPqOUeict4VfA/7HKPnV5k9MV9C43JfMW6+eA/3GJ9yNa/uj/W0RGREQXkUdExAb+CvhREfnB+dcdEXmviGxfpJ0SLRfP74vIj4tIXERMEfmIiPznNrv7aeCDIvLPRMQQkV4Recsix/0J8Msi8rC0SIjID4tISkQOisj75/vfBBpA1Ob5u3QBuoLe5ebmt4DEMu//BnAMeB6YpWXRakqpi7RCDv8DME3LMv63LPF7V0r9LvBvaC3CLhz/q8CX2umkUuoCLf/8r8/34whw3yLHvUBr1vFfgTngNPDx+bdt4HdozTwmgQHgf2nn/F26LCDdAhddunTpcmvQtdC7dOnS5RahK+hdunTpcovQFfQuXbp0uUXoCnqXLl263CJs2db/vr4+tXv37q06fZcuXbq8IXnxxRcLSqlFd1FvmaDv3r2bF154YatO36VLly5vSETk/FLvdV0uXbp06XKL0BX0Ll26dLlF6Ap6ly5dutwidPOht4FSijCMiJRC1zU0EURkq7vVpUuXLtfQFfTriCLF1GyFwlyVC5NzXJiYY7pYvXKAAtPQGe5Ps3ukh+G+DIO9KbKp2NZ1ukuXLl3oCvpl6k2P42cnefLIOcq1BmpeuOOOyWBvCu0qizwMI2ZLNS5NFVG0LPh92/t4+z272b2tB13rerI2imbDozBRpDBR5NKZPLP5Eo
EXommCFTMZ3tnH8O4+egcz9Axm0PXud9HlzcOWJed68MEH1c0Qtlhvenz38BleOH6RKIrIJB3ijrWqNiKlKFUaNN2AZMLm/Q8d4N4D29C0rlumEyilmB6f49jTp3nlubOEUYSKFE7MwnJMNE1QQBRENOouoR+CCMlMjIfedwe3vWU3iZSz1ZfRpUtHEJEXlVIPLvrem1XQlVKcvljgse8co+H69GYTGB2w5hquz2ypzv4d/fzIu+/qumLWydx0ma9//jkunpxEN3WyfSkMo71Sm826S2m2hqZpPPj+O3j4A3dj2d1JaZc3Nl1Bv46m5/PEU69z+PVLZFMxErHVWeQroZRiplhHAT/0rju478C27iLqKgnDiJefPsWhLx9GNzSyfak138MwCClMFMkNpPnIzz7K8K6+Dvd261BKUfKa5JsVxutlJutlvDAgRGFqOmnDZkcqx4CTpN9JYendutNvdLqCfhW1hsdff+0wl6aKDPWmNtQt4noB08Ua73twP+95YH9X1NvEa/p85a+e5PQrl+gbzGB2yKquFGvUK00++LGHufeRN/b3kW9UeH76As9NX6AZ+oAgKCzNQBcNpOUKDKKIQIUIrWvdmczxrqG9HMwMdsX9Dcpygr7ikyIifw78CJBXSt29yPsC/B6tii114ONKqcPr6/LGUG96fPrxF5ieqzK8DouvXWzLYKg3xbefP00YKt7/tgNvaBHZDLymz5f+/BAXT08xtKOno/crlU3gxG2e+NwzBH7IA++5ve3PVn2XfKPKZL3CaGWOeuARRgpdE2K6xa5UluF4moF4kqRhbcj3rJTiRCnPoYnTnKnMoKPRY8foseNtfT5SikKzyl+efgFHN3nX4B4eHthN2uquL9wqrGihi8i7gSrwF0sI+g8Bv0ZL0B8Gfk8p9fBKJ95sC90PQj79+Atcmiox2Jvc0HNFkaJe96jXXSrlBuVKk1Ktyb7+LHsGe+jpSbJzZy99fSn6+lPEOuzyeaMShhGP/fkhzr42zsC23IYNfoEfMD1e5If/+Tu486G9Sx7XDHxemZ3iO+NnmGxU0BAipXAMA0Na+xEipQhVRCMM0BAUil4nwXuH93Jv3zBxozPfbdFt8KXzL/PK3CRJwyJrxdZ1f9wwYNatY2o6P7nnXt7S03ULvlFYl4WulPquiOxe5pCP0hJ7BTwjIlkRGVZKTayptxvEk0fOMjo+x0h/asPO4boB+XyZ8bE5giBEAbomGIZO3DQ5M10kZduUSw1OvD6OzG9Quuvubbzl/t0MDWWWfKiCqIYfVVGqFcGhYWDpWTQxN+x6NpuXnzrF6VfHOm6ZX49hGvQOZnjib55lZE8/2b5rfxNlr8mh8bM8NXUePwzJWA7b4um2+qSUohZ4fP7sMb40+ipvG9jB+7btI9emFb1Ye0dmxvji6MuEKmJ7fOnfyGqwdYPheJpG4PPp0y9yLDfOR3ffQ8bqLuK/kemEc3IbraK6C1yaf+0GQReRTwCfANi5c2cHTt0e49MlvvviGQZ7kxsiFL4fcn60wNRUCQDHMbEX8ftGojg7W+ShPSNktdYDHoYRrx0f5+WjFxkeyfHhj9zLwEAaLyxTdl+j4o9S80cJoiqtTA0KkPn/K2xjgKS5m5S1j7R14A0r8LP5Mt/58ov0DXZGsFbCckw0XXjic8/wT375A4gI337hJC+MXqLQ30AcoT+WwNRW52cWEZKmTdK0CaKQZ/MXeGH6Ej+x524e6N9+zX6GlQijiMcuvMKTU+fosxPEjM5/tzHDZLue4fVSntFXvssv3P52RuKZjp+ny+awqTFcSqk/Bv4YWi6XzTinH4R86dvHiDtmR8ISr0YpmJurcerkJEEQEo8v7zt1TINK0+N8ocTegRwAuq7R25tEKcXcbIVP/80XuO8dNbIjk4iALjEMiePoQze0rVREGDWZbRym0HgWXWL0xx6hJ3Y/tp7r6LVuJEopvvG3z2KYescWQNsh15/m/KlJXnvxHMk9OT
75neeZcxvs8LIcvHdo3e0bms5wPE0z8PnsqSMcnZngn+69h6y9shUcRhGfHz3C89MX2RbPrGogWC0iwmAsRdFr8Af
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEICAYAAABS0fM3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAdyElEQVR4nO3dfZAcd33n8fdnrdjc8CjZW0K2tDM2Ns+XGLT4IBAgwU9QnAUJMeKWQubhtkziu8tB7iJu64jPV5sYqAsUOR9hA45NdgqLmOOs46CMH3DB5SLHayIb22BrLXbXErK9QbbBLDGW9nt/dK88Gs3szmgeenb686rqUs+vfz39VU9Pf7f795v+KSIwM7P8Gsg6ADMzy5YTgZlZzjkRmJnlnBOBmVnOORGYmeWcE4GZWc45EViuSBqR9K2s4zDrJU4E1nckvUHS/5P0hKSDkv5W0msAIqIcEed3OZ7LJT0t6Wfp9ICk/y5pQxPvEZLO7GScll9OBNZXJD0P+Drw58A64DTgvwBPZRkXsCMinksS0zuBFwJ3NpMMzDrFicD6zYsBIuLLEXE4In4REd+KiLsBJF0i6f8uVU7/0r5U0h5Jj0u6SpLSZZdLmqyoW0rrr6l4r73pX/k/kjSyUnAR8XRE3Au8G5gHPlrx/v9a0nR6FbNT0qlp+XfSKndJelLSuyWdIunracwHJX1Xkr/Pdlx84Fi/eQA4LOlaSW+VtLaBdd4OvAb4VeBi4IKVVpD0bOCzwFvTv/R/HdjdaJARcRi4AfiN9P1+C/jTdPsbgFngurTuG9PVfi0inhMRO0gSyD5gEFgP/CfAz4ux4+JEYH0lIn4KvIHkpPiXwHz61/X6ZVa7MiIej4g54NvA2Q1ubhF4paR/FhEH0r/0m/FjkltFACPA1RHxvYh4CvgY8DpJpTrrPk2SMIrpVcZ3ww8Os+PkRGB9JyJ+EBGXRMRG4JXAqcBnllnl4Yr5BeA5DWzj5yS3dy4FDkj6P5Je2mSopwEH0/lTSa4Clt7/SeAnaZ1aPgVMA99Kb09tb3LbZkc4EVhfi4gfAteQJIRm/RwoVLx+YdV73xgR55H8Zf5DkiuQhqT38/8l8N206MdAsWL5s4GTgf211o+In0XERyPiDOAi4COS3tLo9s0qORFYX5H0UkkflbQxfb0JeA+w6zjebjfwRklDkp5PcrtmaTvrJW1JT9hPAU+S3CpaKb41kl4GfJkksfxZuujLwPslnS3pJOBPgNsjYiZd/ghwRsX7vF3SmWnD9hPA4Ua2b1aLE4H1m58B/wK4XdLPSRLAPVT0zmlURNwE7ADuBu4k6Za6ZAD4CMlf8geBNwEfXubt3i3pSZKT9k6S2z6bI+LH6bZuBv4z8FXgAPAiYGvF+pcD16a9hC4GzgJuJklAfwf8j4j4drP/RzMAuX3JzCzffEVgZpZzTgRmZjnnRGBmlnNOBGZmObcm6wCOxymnnBKlUinrMMzMVpU777zzHyNisLp8VSaCUqnE1NRU1mGYma0qkmZrlfvWkJlZzjkRmJnlnBOBmVnOtSURSLpa0qOS7qmzXJI+mw66cbekV1cs25YOCrJH0rZ2xGNmZo1r1xXBNcCFyyx/K8mzUc4CRoHPAUhaB/wxybNhzgH+uMGBRMzMrE3akggi4js881z1WrYAX4rELuAF6VitFwA3RcTBiHgMuInlE0pPKZfLlEolBgYGKJVKlMvlrEMyM2tat7qPngY8VPF6X1pWr/wYkkZJriYYGhrqTJRNKJfLjI6OsrCwAMDs7Cyjo6MAjIysOHStmVnPWDWNxRExERHDETE8OHjM7yG6bmxs7EgSWLKwsMDY2FhGEeWHr8TM2qtbiWA/sKni9ca0rF55z5ubm2uq3Npj6UpsdnaWiDhyJeZkYHb8upUIdgLvS3sPvRZ4IiIOADcC50tamzYSn5+W9bx6t6d64bZVP/OVWLYavRrzVdsqEx
EtTyTD7B0Ania5z/9BkkG9L02XC7gKeBD4PjBcse4HSAbhngbe38j2Nm/eHFmbnJyMQqEQwJGpUCjE5ORk1qH1NUlH7fOlSVLWofW9Ro95fzd6FzAVtc7htQp7feqFRBCRHPDFYjEkRbFY9IHeBcVisWYiKBaLWYfW9xrd9/6Mele9RLBqGot70cjICDMzMywuLjIzM+PeQl0wPj5OoVA4qqxQKDA+Pl6zvm9RtE+j7WJuP2u/jh/HtbJDr0+9ckVg2Wj0Ssy3KNrLVwTZaOdxjG8NWd74hNRebiPIRjuP43qJwLeGrG/5FkV7jYyMMDExQbFYRBLFYpGJiYljbok2Ws8a043j2ImgS3yvuvvcxbf9Gm0Xc/tZ+3TjOHYi6AL/CCobzTYsm/WirhzHte4X9fq02toIfK86O+7ia/2gXccxddoIlCxbXYaHh2M1jVk8MDBArf0sicXFxQwiMrM8knRnRAxXl/vWUBf4XrWZ9TIngi7wvWoz62VOBF3g7nRm1svcRmBmlhNuIzAzs5qcCMzMcs6JwMws55wIzMxyri2JQNKFku6XNC1pe43ln5a0O50ekPR4xbLDFct2tiMeMzNr3JpW30DSCSTDUJ5HMkzlHZJ2RsR9S3Ui4t9X1P83wKsq3uIXEXF2q3GYmdnxaccVwTnAdETsjYhfAtcBW5ap/x6SMY7NzKwHtCMRnAY8VPF6X1p2DElF4HTg1oriZ0makrRL0jvqbUTSaFpvan5+vg1hm5kZdL+xeCtwfUQcrigrpj9w+FfAZyS9qNaKETEREcMRMTw4ONiNWM3McqEdiWA/sKni9ca0rJatVN0Wioj96b97gds4uv3AzMw6rB2J4A7gLEmnSzqR5GR/TO8fSS8F1gJ/V1G2VtJJ6fwpwOuB+6rXNTOzzmm511BEHJJ0GXAjcAJwdUTcK+kKkkEQlpLCVuC6OPrhRi8DPi9pkSQpXVnZ28jMzDrPD50zM8sJP3TOzMxqciIwM8s5JwIzs5xzIqihXC5TKpUYGBigVCpRLpezDsnMrGNa7jXUb8rlMqOjoywsLAAwOzvL6OgogIeWNLO+5CuCKmNjY0eSwJKFhQXGxsYyisjMrLOcCKrMzc01VW5mtto5EVQZGhpqqtzMbLVzIqgyPj5OoVA4qqxQKDA+Pp5RRGZmneVEUGVkZISJiQmKxSKSKBaLTExMuKG4C9xbyywbfsSE9YTq3lqQXIk5CZu1jx8xYT3NvbXMsuNEYD3BvbXMsuNEYD3BvbXMsuNEYD3BvbXMsuNEYD3BvbWyk3Vvray3b0BEtDwBFwL3A9PA9hrLLwHmgd3p9KGKZduAPem0rZHtbd68OcysdZOTk1EoFAI4MhUKhZicnMzF9vOGZNTIY86pLXcflXQC8ABwHrCPZAzj90TFkJOSLgGGI+KyqnXXAVPAcHoQ3AlsjojHltumu4+atUepVGJ2dvaY8mKxyMzMTN9vP2862X30HGA6IvZGxC+B64AtDa57AXBTRBxMT/43kVxdmFkXZN1bK+vtW6IdieA04KGK1/vSsmq/I+luSddL2tTkukgalTQlaWp+fr4NYZtZ1r21st6+JbrVWPy/gVJE/CrJX/3XNvsGETEREcMRMTw4ONj2AM3yKOveWllv3xLtSAT7gU0VrzemZUdExE8i4qn05ReAzY2ua2adk3Vvray3b6laLcjNTCSjnO0FTgdOBO4CXlFVZ0PF/DuBXen8OuBHwNp0+hGwbqVtuteQtdvk5GQUi8WQFMVi0b1WrCu6fdxRp9dQy0NVRsQhSZcBNwInAFdHxL2Srkg3uhP4t5IuAg4BB0m6kxIRByX9V5KeRgBXRMTBVmMya4aHJ7Us9NJx56ePWu65C6NlIYvjzk8fNavDXRgtC7103DkRWO65C6NloZeOOycCyz13YbQs9NJx50RguecujJaFXjru3FjcY8rlMmNjY8zNzTE0NMT4+LhPSGbWFvUai1vuPmrt00vdycwsP3xrqI
d43F4zy4ITQQ/ppe5kZpYfTgQ9pJe6k5lZfjgR9JBe6k5mZvnhRNBDeqk7mZnlh7uPmpnlhJ81ZGZmNTkRmJnlnBO
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQTUlEQVR4nO3de4wd5X3G8e8TZzHlJjAQYoyJgRJSq0oM3RpIKaWQBEKoAImmkDYCicbkgloqopZSpUAbKogaKFUqmqUgnIRrAghEaIFaUWkSalioMSbckVEwxuZuQxpjzNM/Zqwcr87ZXZ/rrt/nI63OnLn+drTPvnPmPTMj20TEtu99gy4gIvojYY8oRMIeUYiEPaIQCXtEIRL2iEIk7IWSdK2krw+6juifhD36RtI8SZb0/qm4vm1dwh5RiIR9ACT9laRVktZLelLSMfX4hZLul/SGpNWSviVpu4blLOnLkp6ul/17SQdI+qmkdZJu3jy/pKMkvSDpfEmvSFop6Y/HqekEScvqbf9U0kfHmffjkh6U9Gb9+vGGaSslfaLh/YWSvle/va9+fUPSW5IOl3SGpJ/Uv+ubkp7YvD/aWd9E+75kCXufSToIOBv4bds7A8cCK+vJm4C/APYADgeOAb48ZhXHAr8FHAb8JTAC/AkwF/hN4LSGeT9Yr2sOcDowUm9/bE0HA9cAZwG7A98G7pA0s8m8s4AfAv9cz3sZ8ENJu0/i1z+yft3V9k6276/fHwo8W9d6AXBrvZ121xdNJOz9twmYCcyXNGR7pe1nAWw/ZPt/bL9reyVV6H5vzPLfsL3O9mPACuAe28/ZfhP4d+DgMfN/zfYG2/9FFdLPNqlpEfBt20ttb7K9GNhA9Q9lrM8AT9v+bl3nDcATwB+0sS82Wwv8k+2Ntm8Cnqy3E12UsPeZ7WeAc4ALgbWSbpS0N4CkD0u6U9JLktYB/0DV2jVa0zD8f03e79Tw/nXbbze8fx7Yu0lZHwLOrQ/h35D0BtWRQrN5967X0+h5qqOHdq3ylldktaozOpCwD4Dt620fQRUyA5fWk66kaiUPtL0LcD6gDja1m6QdG97vC7zYZL6fAxfb3rXhZ4e61R7rxbruRvsCq+rht4EdGqZ9sGG41SWWcyQ1/p6NdbazvmgiYe8zSQdJOrr+PPxLqtb4vXryzsA64C1JHwG+1IVNXiRpO0m/C5wAfL/JPFcBX5R0qCo7SvqMpJ2bzHsX8GFJn5P0fkl/BMwH7qynLwNOlTQkaRg4pWHZl6l+1/3HrPMDwJ/Vy/wh8Bv1dtpdXzSRsPffTOAS4BXgJao/9L+up30V+BywniqAN3W4rZeA16layeuAL9p+YuxMtkeBLwDfqud/Bjij2Qptv0r1T+Nc4FWqk4Qn2H6lnuVrwAH1ei4Crm9Y9hfAxcBP6o8Lm88JLAUOpNonFwOn1Ntpd33RhHLzim2TpKOA79neZ9C1jEfSGcCf1h9roofSskcUImGPKEQO4yMKkZY9ohB9vVpoj1kzPG/uUNNpTy3foen4iJi8X/I273hD0+9mdBR2SccBVwAzgH+zfcl488+bO8QDd89tOu3YvRd0UkpEAEu9pOW0tg/jJc0A/gX4NNWXKk6TNL/d9UVEb3XymX0h8Ex9EcY7wI3Aid0pKyK6rZOwz6H6TvVmL9DkYghJiySNShp9+dVNHWwuIjrR87PxtkdsD9se3nP3Gb3eXES00EnYV1FdBrnZPvzqyqeImGI6ORv/IHCgpP2oQn4q1UUcLT21fIeWZ93vfnFZy+Vypj6ic22H3fa7ks4G7qbqerumvntKRExBHfWz276LX113HBFTWL4uG1GIhD2iEAl7RCES9ohCTJlnZI3XvdaqWy5dchGTl5Y9ohAJe0QhEvaIQiTsEYVI2CMKMWXOxo8nF89EdC4te0QhEvaIQiTsEYVI2CMKkbBHFCJhjyjEtOh6a6Wdi2cmWi5iW5
WWPaIQCXtEIRL2iEIk7BGFSNgjCpGwRxRiWne9jSfdchFb6ijsklYC64FNwLu2h7tRVER0Xzda9t+3/UoX1hMRPZTP7BGF6DTsBu6R9JCkRc1mkLRI0qik0Y1s6HBzEdGuTg/jj7C9StIHgHslPWH7vsYZbI8AIwC7aJY73F5EtKmjlt32qvp1LXAbsLAbRUVE97XdskvaEXif7fX18KeAv+taZT2UbrkoUSeH8XsBt0navJ7rbf9HV6qKiK5rO+y2nwM+1sVaIqKH0vUWUYiEPaIQCXtEIRL2iEJss1e9taudbrl0ycV0kJY9ohAJe0QhEvaIQiTsEYVI2CMKkbPxW6HVWfdcPBPTQVr2iEIk7BGFSNgjCpGwRxQiYY8oRMIeUYh0vXVB7mkX00Fa9ohCJOwRhUjYIwqRsEcUImGPKETCHlGIdL31WLrlYqqYsGWXdI2ktZJWNIybJeleSU/Xr7v1tsyI6NRkDuOvBY4bM+48YIntA4El9fuImMImDHv9vPXXxow+EVhcDy8GTupyXRHRZe1+Zt/L9up6+CWqJ7o2JWkRsAhge3Zoc3MR0amOz8bbNuBxpo/YHrY9PMTMTjcXEW1qN+xrJM0GqF/Xdq+kiOiFdg/j7wBOBy6pX2/vWkUFSbdc9NNkut5uAO4HDpL0gqQzqUL+SUlPA5+o30fEFDZhy277tBaTjulyLRHRQ/m6bEQhEvaIQiTsEYVI2CMKkavepqh0y0W3pWWPKETCHlGIhD2iEAl7RCES9ohCJOwRhUjX2zTUTrdcuuQiLXtEIRL2iEIk7BGFSNgjCpGwRxQiZ+O3Ma3OuufimUjLHlGIhD2iEAl7RCES9ohCJOwRhUjYIwqRrrdC5J52MZnHP10jaa2kFQ3jLpS0StKy+uf43pYZEZ2azGH8tcBxTcZfbntB/XNXd8uKiG6bMOy27wNe60MtEdFDnZygO1vS8vowf7dWM0laJGlU0uhGNnSwuYjoRLthvxI4AFgArAa+2WpG2yO2h20PDzGzzc1FRKfaCrvtNbY32X4PuApY2N2yIqLb2up6kzTb9ur67cnAivHmj6kt3XJlmDDskm4AjgL2kPQCcAFwlKQFgIGVwFk9rDEiumDCsNs+rcnoq3tQS0T0UL4uG1GIhD2iEAl7RCES9ohC5Kq3GFe65bYdadkjCpGwRxQiYY8oRMIeUYiEPaIQCXtEIdL1Fm1rp1suXXKDk5Y9ohAJe0QhEvaIQiTsEYVI2CMKkbPx0ROtzrrn4pnBScseUYiEPaIQCXtEIRL2iEIk7BGFSNgjCjGZJ8LMBb4D7EX1BJgR21dImgXcBMyjeirMZ22/3rtSY1uQe9oNzmRa9neBc23PBw4DviJpPnAesMT2gcCS+n1ETFETht32atsP18PrgceBOcCJwOJ6tsXASb0qMiI6t1Wf2SXNAw4GlgJ7NTzJ9SWqw/yImKImHXZJOwG3AOfYXtc4zbapPs83W26RpFFJoxvZ0FGxEdG+SYVd0hBV0K+zfWs9eo2k2fX02cDaZsvaHrE9bHt4iJndqDki2jBh2CWJ6hHNj9u+rGHSHcDp9fDpwO3dLy8iukXVEfg4M0hHAP8NPAq8V48+n+pz+83AvsDzVF1vr423rl00y4fqmE5rjgKlW25ylnoJ6/yamk2bsJ/d9o+BpgsDSW7ENJFv0EUUImGPKETCHlGIhD2iEAl7RCFyw8mYFnK1XOfSskcUImGPKETCHlGIhD2iEAl7RCES9ohCpOstpr12uuVK7JJLyx5RiIQ9ohAJe0QhEvaIQiTsEYXI2fjYprU6617ixTNp2SMKkbBHFCJhjyhEwh5RiIQ9ohAJe0QhJux6kzQX+A7VI5kNjNi+QtKFwBeAl+tZz7d9V68KjeimEu9pN5l+9neBc20/LGln4CFJ99bTLrf9j70rLyK6ZTLPelsNrK6H10t6HJjT68Iioru26jO7pHnAwVRPcAU4W9JySddI2q3LtUVEF0067JJ2Am
4BzrG9DrgSOABYQNXyf7PFcoskjUoa3ciGLpQcEe2YVNglDVEF/TrbtwLYXmN7k+33gKuAhc2WtT1ie9j28BAzu1V
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"# Fixed seed so the random scatter plot below is identical on every run\n",
"np.random.seed(0)\n",
"\n",
"# Scatter plot: N random points with random colors and random marker sizes\n",
"N = task_params['num_scatter_samples']\n",
"x = np.random.rand(N)\n",
"y = np.random.rand(N)\n",
"colors = np.random.rand(N)\n",
"# s is the marker area in points^2: marker sizes of 0 to 50 points, i.e. 0 to 25 point radii\n",
"area = (50 * np.random.rand(N))**2\n",
"plt.scatter(x, y, s=area, c=colors, alpha=0.5)\n",
"plt.title('Nice Circles')\n",
"plt.show()\n",
"\n",
"# Sine wave sampled at sin_steps evenly spaced points in [0, sin_max_value]\n",
"x = np.linspace(0, task_params['sin_max_value'], task_params['sin_steps'])\n",
"y = np.sin(x)\n",
"plt.plot(x, y, 'o', color='black')\n",
"plt.title('Sinus Dots')\n",
"plt.show()\n",
"\n",
"# 32x32 identity matrix rendered as an image (a single diagonal line)\n",
"m = np.eye(32, 32, dtype=np.uint8)\n",
"plt.imshow(m)\n",
"plt.title('sample output')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras training example\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Training hyper-parameters. Keys added to the connected task_params dict are\n",
"# traced and updated in TRAINS as well, so they are recorded with the experiment.\n",
"task_params['batch_size'] = 128\n",
"task_params['nb_classes'] = 10\n",
"task_params['nb_epoch'] = 6\n",
"task_params['hidden_dim'] = 512\n",
"\n",
"# Local aliases used by the training cell below\n",
"batch_size, nb_classes, nb_epoch = (\n",
"    task_params['batch_size'], task_params['nb_classes'], task_params['nb_epoch'])\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60000 train samples\n",
"10000 test samples\n",
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense (Dense) (None, 512) 401920 \n",
"_________________________________________________________________\n",
"activation (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 512) 262656 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 669,706\n",
"Trainable params: 669,706\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Train on 60000 samples, validate on 10000 samples\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0615 20:51:48.301550 139739992581888 ag_logging.py:146] Entity <function Function._initialize_uninitialized_variables.<locals>.initialize_variables at 0x7f174c2f7a60> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Num'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"WARNING: Entity <function Function._initialize_uninitialized_variables.<locals>.initialize_variables at 0x7f174c2f7a60> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Num'\n",
"60000/60000 [==============================] - 7s 110us/sample - loss: 0.2210 - accuracy: 0.9313 - val_loss: 0.1319 - val_accuracy: 0.9581\n",
"Epoch 2/6\n",
"60000/60000 [==============================] - 5s 85us/sample - loss: 0.0814 - accuracy: 0.9756 - val_loss: 0.0814 - val_accuracy: 0.9773\n",
"Epoch 3/6\n",
"60000/60000 [==============================] - 6s 92us/sample - loss: 0.0541 - accuracy: 0.9832 - val_loss: 0.0719 - val_accuracy: 0.9789\n",
"Epoch 4/6\n",
"60000/60000 [==============================] - 6s 92us/sample - loss: 0.0377 - accuracy: 0.9884 - val_loss: 0.0879 - val_accuracy: 0.9769\n",
"Epoch 5/6\n",
"60000/60000 [==============================] - 5s 83us/sample - loss: 0.0290 - accuracy: 0.9911 - val_loss: 0.0713 - val_accuracy: 0.9812\n",
"Epoch 6/6\n",
"60000/60000 [==============================] - 5s 86us/sample - loss: 0.0238 - accuracy: 0.9927 - val_loss: 0.0900 - val_accuracy: 0.9804\n",
"Test score: 0.09002585870867187\n",
"Test accuracy: 0.9804\n"
]
}
],
"source": [
"# Load MNIST, already split into train and test sets\n",
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"# Flatten each image into one feature vector and scale pixels by 1/255 (MNIST\n",
"# pixels are 0-255). Sizes come from the loaded arrays instead of being hard-coded.\n",
"X_train = X_train.reshape(X_train.shape[0], -1).astype('float32') / 255.\n",
"X_test = X_test.reshape(X_test.shape[0], -1).astype('float32') / 255.\n",
"input_dim = X_train.shape[1]\n",
"print(X_train.shape[0], 'train samples')\n",
"print(X_test.shape[0], 'test samples')\n",
"\n",
"# Convert class vectors to one-hot (binary class) matrices with nb_classes columns\n",
"Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
"Y_test = np_utils.to_categorical(y_test, nb_classes)\n",
"\n",
"# Two hidden ReLU layers followed by a softmax output layer. The output width must\n",
"# match the one-hot labels above, so it uses nb_classes (a tracked parameter that\n",
"# may be changed) rather than a hard-coded 10.\n",
"hidden_dim = task_params['hidden_dim']\n",
"model = Sequential()\n",
"model.add(Dense(hidden_dim, input_shape=(input_dim,)))\n",
"model.add(Activation('relu'))\n",
"model.add(Dense(hidden_dim))\n",
"model.add(Activation('relu'))\n",
"model.add(Dense(nb_classes))\n",
"model.add(Activation('softmax'))\n",
"\n",
"model.summary()\n",
"\n",
"model.compile(loss='categorical_crossentropy',\n",
"              optimizer=RMSprop(),\n",
"              metrics=['accuracy'])\n",
"\n",
"# TensorBoard logs (with weight histograms every epoch) and per-epoch model\n",
"# checkpoints are written under the system temp directory.\n",
"board = TensorBoard(histogram_freq=1, log_dir=os.path.join(tempfile.gettempdir(), 'histogram_example'))\n",
"model_store = ModelCheckpoint(filepath=os.path.join(tempfile.gettempdir(), 'weight.{epoch}.hdf5'))\n",
"\n",
"model.fit(X_train, Y_train,\n",
"          batch_size=batch_size, epochs=nb_epoch,\n",
"          callbacks=[board, model_store],\n",
"          verbose=1, validation_data=(X_test, Y_test))\n",
"\n",
"# With metrics=['accuracy'], evaluate() returns [loss, accuracy]\n",
"score = model.evaluate(X_test, Y_test, verbose=0)\n",
"print('Test score:', score[0])\n",
"print('Test accuracy:', score[1])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}