# Mirror of https://github.com/clearml/clearml (synced 2025-01-31 09:07:00 +00:00)
# 164 lines, 6.8 KiB, Python
import os.path
|
|
from pathlib import Path
|
|
|
|
import matplotlib as mpl
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import torchaudio
|
|
import torch
|
|
from clearml import Task, Dataset
|
|
|
|
# Create the ClearML task for this script. It is kept at module level because the
# PreProcessor and DataSetBuilder classes below connect their configuration dicts to it.
task = Task.init(project_name="examples/Urbansounds", task_name="preprocessing")
|
|
|
|
# Let's preprocess the data and create a new ClearML dataset from it, so we can keep track of it.
# The cool thing is that we can easily debug it by using, you guessed it: debug samples! We can log
# both the original sound and its processed mel spectrogram as debug samples, so we can manually
# check whether everything went as planned.
|
|
|
|
|
|
class PreProcessor:
    """Convert raw audio waveforms into fixed-size, log-scale mel spectrograms.

    The settings are stored in ``self.configuration`` and connected to the
    module-level ClearML ``task``:

    * ``number_of_mel_filters``: number of mel bins passed to ``MelSpectrogram``.
    * ``resample_freq``: target sample rate; a value ``<= 0`` disables resampling.
    """

    def __init__(self):
        self.configuration = {"number_of_mel_filters": 64, "resample_freq": 22050}
        task.connect(self.configuration)

    def preprocess_sample(self, sample, original_sample_freq):
        """Return the dB-scaled mel spectrogram of ``sample``, padded or cut to a fixed width.

        Args:
            sample: Audio tensor whose first dimension is averaged away into a
                single channel (presumably ``(channels, time)`` as returned by
                ``torchaudio.load`` -- confirm against the caller).
            original_sample_freq: Sample rate of ``sample`` before any resampling.

        Returns:
            The spectrogram tensor, zero-padded or truncated along its third
            dimension to ``3 * (sample_rate // 200)`` entries.
        """
        resample_freq = self.configuration["resample_freq"]
        if resample_freq > 0:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=original_sample_freq,
                new_freq=resample_freq,
            )
            sample = resample_transform(sample)
            sample_rate = resample_freq
        else:
            # Resampling is disabled, so the audio is still at its original rate.
            # Using the (non-positive) configured value below would describe the
            # wrong rate to MelSpectrogram and yield a zero/negative target length.
            sample_rate = original_sample_freq

        # This will convert audio files with two channels into one
        sample = torch.mean(sample, dim=0, keepdim=True)

        # Convert audio to log-scale Mel spectrogram
        melspectrogram_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=self.configuration["number_of_mel_filters"],
        )
        melspectrogram = melspectrogram_transform(sample)
        melspectrogram_db = torchaudio.transforms.AmplitudeToDB()(melspectrogram)

        # Make sure all spectrograms are the same size.
        # NOTE(review): 200 presumably matches MelSpectrogram's default hop length,
        # which would make this roughly three seconds of frames -- confirm.
        fixed_length = 3 * (sample_rate // 200)
        current_length = melspectrogram_db.shape[2]
        if current_length < fixed_length:
            # Zero-pad on the right of the last dimension up to the fixed length
            melspectrogram_db = torch.nn.functional.pad(
                melspectrogram_db, (0, fixed_length - current_length)
            )
        else:
            melspectrogram_db = melspectrogram_db[:, :, :fixed_length]

        return melspectrogram_db
|
|
|
|
|
|
class DataSetBuilder:
    """Create a preprocessed child version of the raw UrbanSounds ClearML dataset.

    On construction the raw dataset (tagged ``raw``) is fetched, its ``metadata``
    artifact is loaded and a mutable local copy of the files is downloaded.
    ``build_dataset`` then writes one ``.npy`` spectrogram next to every audio
    file and uploads the resulting folder as a new dataset version.
    """

    def __init__(self):
        self.configuration = {"dataset_path": "dataset"}
        task.connect(self.configuration)

        self.original_dataset = Dataset.get(
            dataset_project="examples/Urbansounds",
            dataset_name="UrbanSounds example",
            dataset_tags=["raw"],
            alias="Raw Dataset",
        )
        # This will return the pandas dataframe we added in the previous task
        self.metadata = (
            Task.get_task(task_id=self.original_dataset._task.id)
            .artifacts["metadata"]
            .get()
        )
        # This will download the data and return a local path to the data
        self.original_dataset_path = Path(
            self.original_dataset.get_mutable_local_copy(
                self.configuration["dataset_path"], overwrite=True
            )
        )

        # Prepare a preprocessor that will handle each sample one by one
        self.preprocessor = PreProcessor()

        # Get ready for the new one; it is created by build_dataset()
        self.preprocessed_dataset = None

    def log_dataset_statistics(self):
        """Report the metadata table and the per-label sample counts.

        Both plots go to the logger of ``self.preprocessed_dataset``, so this
        must only be called after ``build_dataset`` has created that dataset
        (it is ``None`` until then).
        """
        histogram_data = self.metadata["label"].value_counts()
        logger = self.preprocessed_dataset.get_logger()
        logger.report_table(
            title="Raw Dataset Metadata",
            series="Raw Dataset Metadata",
            table_plot=self.metadata,
        )
        logger.report_histogram(
            title="Class distribution",
            series="Class distribution",
            values=histogram_data,
            iteration=0,
            xlabels=histogram_data.index.tolist(),
            yaxis="Amount of samples",
        )

    def build_dataset(self):
        """Preprocess every sample and publish the result as a new dataset version.

        For each metadata row the audio file is loaded, converted to a
        spectrogram by ``self.preprocessor`` and saved as ``.npy`` beside the
        audio file. Every 10th sample is also reported as a debug sample
        (spectrogram image plus original audio). Finally the files, statistics
        and metadata artifact are added and the dataset is finalized.
        """
        # Let's create a new dataset that is a child of the original one
        # We'll add the preprocessed samples to the original dataset, leading to a new version
        # Providing the parent dataset allows us to keep a clear lineage of our data
        self.preprocessed_dataset = Dataset.create(
            dataset_name="UrbanSounds example",
            dataset_project="examples/Urbansounds",
            dataset_tags=["preprocessed"],
            parent_datasets=[self.original_dataset.id],
        )
        logger = self.preprocessed_dataset.get_logger()

        # loop through the metadata entries and preprocess each sample, then add some of them as debug samples to
        # manually double check in the UI that everything has worked (you can watch the spectrogram and listen to the
        # audio side by side in the debug sample UI)
        # tqdm wraps the row iterator (not enumerate) and is given the total,
        # so it can actually display a progress bar and an ETA.
        rows = tqdm(self.metadata.iterrows(), total=len(self.metadata))
        for i, (_, data) in enumerate(rows):
            _, audio_file_path, _label = data.tolist()
            audio_path = self.original_dataset_path / audio_file_path
            sample, sample_freq = torchaudio.load(audio_path, normalize=True)
            spectrogram = self.preprocessor.preprocess_sample(sample, sample_freq)

            # Same folder and file name as the audio file, but with a .npy extension:
            # we are saving the spectrogram as a numpy array here. with_suffix only
            # touches the extension, unlike a substring replace of ".wav".
            spectrogram_path = audio_path.with_suffix(".npy")
            # Save the numpy array to disk
            np.save(spectrogram_path, spectrogram.detach().numpy())

            # Log every 10th sample as a debug sample to the UI, so we can manually check it
            if i % 10 == 0:
                # Convert the spectrogram to a viewable RGB image. A colormap
                # expects floats in [0, 1], so min-max normalise the values first
                # instead of feeding it raw (scaled) dB values that would saturate.
                spectrogram_array = spectrogram[0, :, :].detach().numpy()
                lowest = spectrogram_array.min()
                value_range = spectrogram_array.max() - lowest
                if value_range > 0:
                    normalized = (spectrogram_array - lowest) / value_range
                else:
                    # Constant spectrogram (e.g. silence): avoid dividing by zero
                    normalized = np.zeros_like(spectrogram_array)
                # Drop the alpha channel, keep RGB only
                rgb_image = mpl.colormaps["viridis"](normalized)[:, :, :3]
                title = audio_path.stem

                # Report the image and the original sound, so they can be viewed side by side
                logger.report_image(
                    title=title, series="spectrogram", image=rgb_image
                )
                logger.report_media(
                    title=title,
                    series="original_audio",
                    local_path=audio_path,
                )

        # The original data path will now also have the spectrograms in its filetree.
        # So that's why we add it here to fill up the new dataset with.
        self.preprocessed_dataset.add_files(self.original_dataset_path)
        # Again add some visualizations to the task
        self.log_dataset_statistics()
        # We still want the metadata
        self.preprocessed_dataset._task.upload_artifact(
            name="metadata", artifact_object=self.metadata
        )
        self.preprocessed_dataset.finalize(auto_upload=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: constructing the builder fetches the raw dataset,
    # and build_dataset() creates and uploads the preprocessed version.
    builder = DataSetBuilder()
    builder.build_dataset()
|