M2_SETI/IA/seti_master-master/code/dataset/pytorch_dataset.py
2023-01-29 16:56:40 +01:00

48 lines
1.4 KiB
Python

"""Custom datasets. Must implement __len__ and __getitem__.
The latter should return a dict with data and target vector
"""
import os
import sys
from torch.utils.data import Dataset
import torch
import numpy as np
# Don't forget to normalize inputs
class CollisionDataset(Dataset):
    """Dataset wrapper for the DFKI tutorial.

    Relies on the data generated by generate_collisions.py, loaded with
    ``np.load``. The array is indexed with samples along the first axis; for
    each sample every column but the last is the input and the last column
    is the target.

    Samples whose row index is a multiple of 10 form the test split, all
    other rows form the training split.

    Args:
        fpath: Path of the array file to load.
        train: Serve the training split when true, the test split otherwise.
        transform: Optional callable. When set, it is applied to the whole
            selected split (cast to float32) on every item access; its result
            must support ``squeeze`` and two-axis indexing.

    Raises:
        SystemExit: If ``fpath`` does not exist.
    """

    def __init__(self, fpath, train=True, transform=None):
        self.train = train
        self.transform = transform
        self.fpath = fpath

        if not os.path.exists(self.fpath):
            # Still exits through SystemExit so existing callers observe the
            # same failure mode; only the message text is corrected.
            sys.exit(
                'Error: {} does not exist. Dataset not created'.format(fpath)
            )
        arr = np.load(self.fpath)

        # One sample out of 10 goes to the test data.
        is_test = np.arange(arr.shape[0]) % 10 == 0
        self.traindata = arr[~is_test, :]
        self.testdata = arr[is_test, :]

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self.traindata if self.train else self.testdata)

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for sample ``idx`` of the active split.

        ``inputs`` is every column but the last, ``target`` is the last
        column. Without a transform both are slices of the loaded numpy
        array; with a transform they are slices of whatever it returns.
        """
        data = self.traindata if self.train else self.testdata

        if self.transform:
            # NOTE(review): the transform runs on the full split for every
            # item, which is costly for large files. It is left that way
            # because the transform may be stochastic or shape-dependent, so
            # caching or per-row application could change results.
            split_shape = data.shape
            data = data.astype(np.float32)
            data = self.transform(data)
            # Presumably drops singleton axes added by the transform (e.g. a
            # leading channel axis) -- confirm against the transforms in use.
            data = data.squeeze()
            if data.ndim < 2:
                # squeeze() also removed a length-1 sample (or column) axis,
                # e.g. a split holding a single row; restore the split's
                # shape so the two-axis indexing below keeps working.
                data = data.reshape(split_shape)

        return data[idx, :-1], data[idx, -1].squeeze()