48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
"""Custom datasets. Each must implement __len__ and __getitem__.

__getitem__ should return the input data and its target; CollisionDataset
returns them as a ``(data, target)`` tuple.
"""
|
|
import os
|
|
import sys
|
|
|
|
from torch.utils.data import Dataset
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
# Don't forget to normalize inputs
|
|
class CollisionDataset(Dataset):
    """Dataset wrapper for the DFKI tutorial.

    Relies on the data generated by generate_collisions.py, loaded with
    ``np.load``. The loaded array is indexed as ``arr[mask, :]``, so it is
    expected to be at least 2-D with one sample per row. Every row whose index
    is a multiple of 10 (0, 10, 20, ...) goes to the test split; all other
    rows form the training split. Within a row, all columns except the last
    are the inputs and the last column is the target.

    Args:
        fpath: Path of the array file to load.
        train: If true, use the training split, otherwise the test split.
            It is read on every ``__len__`` / ``__getitem__`` call.
        transform: Optional callable applied to the whole selected split
            (after casting it to float32) before a row is extracted.
    """

    def __init__(self, fpath, train=True, transform=None):
        self.train = train
        self.transform = transform
        self.fpath = fpath

        if not os.path.exists(self.fpath):
            # NOTE(review): sys.exit raises SystemExit, which terminates the
            # process if uncaught. Kept for backward compatibility; raising
            # FileNotFoundError would be friendlier for library callers.
            sys.exit('Error: {} does not exist. Dataset not created'
                     .format(fpath))
        arr = np.load(self.fpath)

        # One row out of 10 goes to the test split.
        is_test = np.arange(arr.shape[0]) % 10 == 0
        self.traindata = arr[~is_test, :]
        self.testdata = arr[is_test, :]

    def _split(self):
        """Return the raw array of the split selected by ``self.train``."""
        return self.traindata if self.train else self.testdata

    def __len__(self):
        """Return the number of rows in the currently selected split."""
        return len(self._split())

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for row ``idx`` of the selected split.

        ``inputs`` is every column of the row except the last; ``target`` is
        the last column, squeezed.

        NOTE(review): when a transform is set it is re-applied to the whole
        split on every call, so one pass over the dataset is O(N^2). If the
        transform is deterministic, consider caching its result.
        """
        data = self._split()
        if self.transform:
            data = self.transform(data.astype(np.float32))
            # Remove singleton dimensions the transform may add (presumably a
            # leading channel axis, e.g. from ToTensor) so rows index as 2-D.
            data = data.squeeze()
        return data[idx, :-1], data[idx, -1].squeeze()
|