# clearml/examples/datasets/urbansounds_dataset_preprocessing.py
import os.path
from pathlib import Path

import matplotlib as mpl
import numpy as np
import torch
import torchaudio
from clearml import Task, Dataset
from tqdm import tqdm

task = Task.init(project_name="examples/Urbansounds", task_name="preprocessing")


# Let's preprocess the data and create a new ClearML dataset from it, so we can track it
# through the pipeline. The cool thing is that we can easily debug by using, you guessed it,
# debug samples! We can log both the original sound and its processed mel spectrogram as
# debug samples, so we can manually check that everything went as planned.
class PreProcessor:
    def __init__(self):
        self.configuration = {"number_of_mel_filters": 64, "resample_freq": 22050}
        task.connect(self.configuration)
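        # Note: task.connect() registers this dict with the ClearML task, so the values
        # can be inspected and overridden from the UI; any override is reflected back
        # into self.configuration at runtime.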

    def preprocess_sample(self, sample, original_sample_freq):
        # Resample to a fixed frequency, so all samples share the same time base
        if self.configuration["resample_freq"] > 0:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=original_sample_freq,
                new_freq=self.configuration["resample_freq"],
            )
            sample = resample_transform(sample)
        # Downmix audio files with two channels into one
        sample = torch.mean(sample, dim=0, keepdim=True)
        # Convert the audio to a log-scale (dB) mel spectrogram
        melspectrogram_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.configuration["resample_freq"],
            n_mels=self.configuration["number_of_mel_filters"],
        )
        melspectrogram = melspectrogram_transform(sample)
        melspectrogram_db = torchaudio.transforms.AmplitudeToDB()(melspectrogram)
        # Make sure all spectrograms are the same size: pad or crop to 3 seconds'
        # worth of frames (200 is the default MelSpectrogram hop length, n_fft // 2)
        fixed_length = 3 * (self.configuration["resample_freq"] // 200)
        if melspectrogram_db.shape[2] < fixed_length:
            melspectrogram_db = torch.nn.functional.pad(
                melspectrogram_db, (0, fixed_length - melspectrogram_db.shape[2])
            )
        else:
            melspectrogram_db = melspectrogram_db[:, :, :fixed_length]
        return melspectrogram_db
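
    # A minimal, hypothetical sanity check (not part of the original example):
    # one second of stereo silence should come out with the fixed shape
    # (1, number_of_mel_filters, 3 * resample_freq // 200) == (1, 64, 330):
    #
    #     preprocessor = PreProcessor()
    #     dummy = torch.zeros(2, 44100)  # 1 second of 2-channel audio at 44.1 kHz
    #     out = preprocessor.preprocess_sample(dummy, 44100)
    #     assert out.shape == (1, 64, 330)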


class DataSetBuilder:
    def __init__(self):
        self.configuration = {"dataset_path": "dataset"}
        task.connect(self.configuration)
        self.original_dataset = Dataset.get(
            dataset_project="examples/Urbansounds",
            dataset_name="UrbanSounds example",
            dataset_tags=["raw"],
            alias="Raw Dataset",
        )
        # This will return the pandas dataframe we added in the previous task
        self.metadata = (
            Task.get_task(task_id=self.original_dataset._task.id)
            .artifacts["metadata"]
            .get()
        )
        # This will download the data and return a local path to it
        self.original_dataset_path = Path(
            self.original_dataset.get_mutable_local_copy(
                self.configuration["dataset_path"], overwrite=True
            )
        )
        # Prepare a preprocessor that will handle each sample one by one
        self.preprocessor = PreProcessor()
        # Placeholder for the preprocessed dataset version we are about to build
        self.preprocessed_dataset = None
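
    # A note on the assumed metadata layout (inferred from how rows are unpacked in
    # build_dataset() below, not stated in this script): each row carries three
    # columns, roughly
    #
    #     <fold or index>    <relative audio path>       <label>
    #     1                  fold1/101415-3-0-2.wav      dog_bark
    #
    # where the first column is unused here.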

    def log_dataset_statistics(self):
        # Attach the raw metadata table and the class balance to the new dataset,
        # so they show up on the dataset's page in the UI
        histogram_data = self.metadata["label"].value_counts()
        self.preprocessed_dataset.get_logger().report_table(
            title="Raw Dataset Metadata",
            series="Raw Dataset Metadata",
            table_plot=self.metadata,
        )
        self.preprocessed_dataset.get_logger().report_histogram(
            title="Class distribution",
            series="Class distribution",
            values=histogram_data,
            iteration=0,
            xlabels=histogram_data.index.tolist(),
            yaxis="Amount of samples",
        )

    def build_dataset(self):
        # Let's create a new dataset that is a child of the original one.
        # We'll add the preprocessed samples to the original dataset, leading to a new version.
        # Providing the parent dataset lets us keep a clear lineage of our data.
        self.preprocessed_dataset = Dataset.create(
            dataset_name="UrbanSounds example",
            dataset_project="examples/Urbansounds",
            dataset_tags=["preprocessed"],
            parent_datasets=[self.original_dataset.id],
        )
        # Loop through the metadata entries and preprocess each sample, then log some of them
        # as debug samples, so we can manually double-check in the UI that everything worked
        # (you can watch the spectrogram and listen to the audio side by side in the debug sample UI)
        for i, (_, data) in enumerate(tqdm(self.metadata.iterrows(), total=len(self.metadata))):
            _, audio_file_path, label = data.tolist()
            sample, sample_freq = torchaudio.load(
                self.original_dataset_path / audio_file_path, normalize=True
            )
            spectrogram = self.preprocessor.preprocess_sample(sample, sample_freq)
            # Keep only the filename and swap the extension: we're saving a numpy array, not audio
            new_file_name = os.path.basename(audio_file_path).replace(".wav", ".npy")
            # Build the destination path: the original sample's folder plus the new filename
            spectrogram_path = (
                self.original_dataset_path
                / os.path.dirname(audio_file_path)
                / new_file_name
            )
            # Save the numpy array to disk
            np.save(spectrogram_path, spectrogram)
            # Log every 10th sample as a debug sample to the UI, so we can manually check it
            if i % 10 == 0:
                # Normalize the dB values to [0, 1] before applying the colormap
                # (matplotlib colormaps expect floats in that range), then scale
                # to 8-bit RGB to get a viewable image
                spec = spectrogram[0, :, :].detach().numpy()
                spec = (spec - spec.min()) / (spec.max() - spec.min() + 1e-9)
                rgb_image = (mpl.colormaps["viridis"](spec)[:, :, :3] * 255).astype(np.uint8)
                title = os.path.splitext(os.path.basename(audio_file_path))[0]
                # Report the image and the original sound, so they can be viewed side by side
                self.preprocessed_dataset.get_logger().report_image(
                    title=title, series="spectrogram", image=rgb_image
                )
                self.preprocessed_dataset.get_logger().report_media(
                    title=title,
                    series="original_audio",
                    local_path=self.original_dataset_path / audio_file_path,
                )
        # The original dataset folder now also contains the spectrograms in its file tree,
        # so we add that whole tree here to fill the new dataset version
        self.preprocessed_dataset.add_files(self.original_dataset_path)
        # Again, add some visualizations to the dataset
        self.log_dataset_statistics()
        # We still want the metadata available on the new version, too
        self.preprocessed_dataset._task.upload_artifact(
            name="metadata", artifact_object=self.metadata
        )
        self.preprocessed_dataset.finalize(auto_upload=True)

if __name__ == "__main__":
    datasetbuilder = DataSetBuilder()
    datasetbuilder.build_dataset()
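
# A hypothetical downstream usage sketch (not part of this script): a training task
# could fetch the finalized version by its tag and pull a local copy of the files:
#
#     from clearml import Dataset
#
#     dataset = Dataset.get(
#         dataset_project="examples/Urbansounds",
#         dataset_name="UrbanSounds example",
#         dataset_tags=["preprocessed"],
#     )
#     local_path = dataset.get_local_copy()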