mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
f6b9efe54e
* Add hyperdataset examples Co-authored-by: Erez Schnaider <erez@clear.ml>
56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
import numpy as np
|
|
import torch.utils.data
|
|
from allegroai import DataView, FrameGroup, Task
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class ExampleDataset(torch.utils.data.Dataset):
    """PyTorch dataset yielding ``(image, mask)`` NumPy array pairs.

    Every item comes from one entry of ``dv.to_list()``. The entry is indexed
    with ``source_key`` to pick a frame; that frame's local source file is
    loaded as an RGB image and its local mask source file is loaded as the
    mask. Both are resized to ``size`` before being converted to arrays.

    Args:
        dv: Object exposing ``to_list()`` (a ``DataView`` in this example).
        source_key: Key used to select the frame inside each list entry.
            Defaults to ``"000002"``, the key this example was written for.
        size: ``(width, height)`` handed to ``PIL.Image.Image.resize``.
    """

    def __init__(self, dv, source_key="000002", size=(256, 256)):
        # automatically adjust dataset to balance all queries
        # NOTE(review): the comment above is kept from the original example;
        # what ``to_list()`` does internally is not visible in this file.
        self.frames = dv.to_list()
        self.source_key = source_key
        self.size = tuple(size)

    def __getitem__(self, idx):
        """Return the resized image and mask of entry ``idx`` as arrays."""
        frame = self.frames[idx][self.source_key]  # one frame of a FrameGroup

        # The context manager closes the underlying file as soon as the
        # pixel data has been copied into the converted/resized image.
        with Image.open(frame.get_local_source()) as raw_img:
            img = raw_img.convert("RGB").resize(self.size)

        # Masks are assumed to encode discrete per-pixel labels. Pillow's
        # default filter interpolates (bicubic for most modes) and would
        # blend neighbouring labels into values that never existed in the
        # mask, so nearest-neighbour resampling is requested explicitly.
        with Image.open(frame.get_local_mask_source()) as raw_mask:
            mask = raw_mask.resize(self.size, resample=Image.NEAREST)

        return np.array(img), np.array(mask)

    def __len__(self):
        """Return the number of entries obtained from the dataview."""
        return len(self.frames)
|
|
|
|
|
|
def main():
    """Run the example: register a task, query the dataset, iterate batches.

    The whole flow lives in a function behind a ``__main__`` guard because
    the ``DataLoader`` below starts a worker process. On platforms that
    spawn workers (Windows, macOS) the main module is re-imported in the
    worker, and unguarded top-level code would run again there.
    """
    Task.init(project_name='examples', task_name='PyTorch Sample Dataset with Masks')

    # Create DataView with example query
    dataview = DataView()
    dataview.add_query(dataset_name='sample-dataset-masks', version_name='Current')
    # dataview.add_query(dataset_name='sample-dataset', version_name='Current', roi_query=["aeroplane"])

    # if we want all files to be downloaded in the background, we can call prefetch
    # dataview.prefetch_files()

    # create PyTorch Dataset
    dataset = ExampleDataset(dataview)

    # do your thing here :)
    print('Fake PyTorch stuff below:')
    print('Dataset length', len(dataset))

    torch.manual_seed(0)
    data_loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        pin_memory=True,
        prefetch_factor=2,
    )
    for i, data in enumerate(data_loader):
        print('{}] {}'.format(i, data))

    print('done')


if __name__ == '__main__':
    main()
|