Update pytorch example trains version and reduced default number of epochs

This commit is contained in:
allegroai 2020-09-05 16:31:44 +03:00
parent 03e7ebd48c
commit aa44ba854f

View File

@ -15,7 +15,7 @@
"! pip install -U torchaudio==0.5.1\n",
"! pip install -U torchvision==0.6.1\n",
"! pip install -U matplotlib==3.2.1\n",
"! pip install -U trains>=0.16.0\n",
"! pip install -U trains>=0.16.1\n",
"! pip install -U pandas==1.0.4\n",
"! pip install -U numpy==1.18.4\n",
"! pip install -U tensorboard==2.2.1"
@ -63,7 +63,7 @@
"outputs": [],
"source": [
"task = Task.init(project_name='Audio Example', task_name='audio classification UrbanSound8K')\n",
"configuration_dict = {'number_of_epochs': 6, 'batch_size': 8, 'dropout': 0.25, 'base_lr': 0.005, \n",
"configuration_dict = {'number_of_epochs': 3, 'batch_size': 8, 'dropout': 0.3, 'base_lr': 0.005, \n",
" 'number_of_mel_filters': 64, 'resample_freq': 22050}\n",
"configuration_dict = task.connect(configuration_dict) # enabling configuration override by trains\n",
"print(configuration_dict) # printing actual configuration (after override in remote mode)"
@ -141,13 +141,13 @@
" melspectogram_db = melspectogram_db[:, :, :fixed_length]\n",
" \n",
" if self.return_audio:\n",
" fixed_length = 2 * self.resample\n",
" fixed_length = 3 * self.resample\n",
" if soundData.numel() < fixed_length:\n",
" soundData = torch.nn.functional.pad(soundData, (0, fixed_length - soundData.numel()))\n",
" soundData = torch.nn.functional.pad(soundData, (0, fixed_length - soundData.numel())).numpy()\n",
" else:\n",
" soundData = soundData[0,:fixed_length].reshape(1,fixed_length)\n",
" soundData = soundData[0,:fixed_length].reshape(1,fixed_length).numpy()\n",
" else:\n",
" soundData = []\n",
" soundData = np.array([])\n",
"\n",
" return soundData, self.resample, melspectogram_db, self.labels[index]\n",
" \n",
@ -164,7 +164,7 @@
"train_loader = torch.utils.data.DataLoader(train_set, batch_size = configuration_dict.get('batch_size', 4), \n",
" shuffle = True, pin_memory=True, num_workers=1)\n",
"test_loader = torch.utils.data.DataLoader(test_set, batch_size = configuration_dict.get('batch_size', 4), \n",
" shuffle = False, pin_memory=True, num_workers=1)\n",
" shuffle = False, pin_memory=False, num_workers=1)\n",
"\n",
"classes = ('air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling', 'engine_idling', \n",
" 'gun_shot', 'jackhammer', 'siren', 'street_music')"