# clearml/examples/hyperdatasets/data-ingestion/pytorch_dataset_example.py
#
# Example: wrap a Hyper-Dataset DataView in a PyTorch Dataset and iterate it
# with a DataLoader.

import numpy as np
import torch.utils.data
from allegroai import DataView, SingleFrame, Task
from PIL import Image
from torch.utils.data import DataLoader
class ExampleDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset backed by the frames of a DataView.

    Each item is the frame's image, converted to RGB, resized to
    ``image_size`` and returned as a NumPy array (``np.array`` of a PIL RGB
    image, i.e. height x width x 3).
    """

    def __init__(self, dv, image_size=(256, 256)):
        """
        :param dv: DataView providing the frames; ``dv.to_list()`` is called
            once here and the resulting list is cached on the instance.
        :param image_size: ``(width, height)`` passed to ``PIL.Image.resize``
            for every item. Defaults to ``(256, 256)``.
        """
        # NOTE(review): the original example says to_list() "automatically
        # adjusts the dataset to balance all queries" -- confirm against the
        # DataView documentation before relying on it.
        self.frames = dv.to_list()
        self.image_size = tuple(image_size)

    def __getitem__(self, idx):
        """Return frame ``idx`` as an RGB image array of size ``image_size``."""
        frame = self.frames[idx]  # type: SingleFrame
        # presumably returns a path to a local copy of the frame's source,
        # fetching it if needed -- see SingleFrame.get_local_source
        img_path = frame.get_local_source()
        # Use a context manager so Pillow's file handle is released
        # deterministically; Pillow only auto-closes after load() for
        # single-frame images, not multi-frame formats (GIF, TIFF, ...).
        # convert()/resize() return new, fully loaded images, so `rgb`
        # stays valid after the file is closed.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB").resize(self.image_size)
        return np.array(rgb)

    def __len__(self):
        """Number of frames captured from the DataView at construction."""
        return len(self.frames)
def main():
    """Build a DataView-backed dataset and iterate over it with a DataLoader.

    The script body lives in a function that runs only under the
    ``__main__`` guard below. ``DataLoader(num_workers=1)`` starts a worker
    process, and with the ``spawn`` start method (the default on Windows and
    macOS) that worker re-imports this module. Unguarded top-level code would
    then re-run ``Task.init`` and try to start another DataLoader while the
    worker is still bootstrapping, which fails.
    """
    # initialize the ClearML task for this run
    task = Task.init(project_name='examples', task_name='PyTorch Sample Dataset')

    # Create DataView with example query
    dataview = DataView()
    dataview.add_query(dataset_name='sample-dataset', version_name='Current')
    # To download all files in the background, call prefetch:
    # dataview.prefetch_files()

    # create PyTorch Dataset
    dataset = ExampleDataset(dataview)

    # do your thing here :)
    print('Fake PyTorch stuff below:')
    print('Dataset length', len(dataset))
    torch.manual_seed(0)
    data_loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        # Pinned memory only speeds up host-to-GPU copies; without CUDA it
        # does nothing useful and PyTorch warns about it.
        pin_memory=torch.cuda.is_available(),
        prefetch_factor=2,
    )
    for i, data in enumerate(data_loader):
        print('{}] {}'.format(i, data))
    print('done')


if __name__ == '__main__':
    main()