"""Code used to train and test the neural network used in the tutorial.

Requires the files architecture.py and customDatasets.py for custom
architectures and dataset handling.
"""

import argparse
import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from tensorboardX import SummaryWriter

from dataset import pytorch_dataset  # local file


def approx_i(x):
    """Approximation of the indicator function of x > 0: the output is 0 for
    x <= 0 and 1 for x >= 1e-3, with a linear ramp on the small width in
    between."""
    return F.relu(1 - F.relu(1 - 1000 * x))
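
# Quick illustration of the gate behaviour (illustrative values):
#   approx_i(torch.tensor(-0.2))   -> tensor(0.)
#   approx_i(torch.tensor(0.0005)) -> tensor(0.5000)
#   approx_i(torch.tensor(0.2))    -> tensor(1.)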


def normalize_tensor(x):
    """Normalize a tensor by its maximum to avoid NaN issues"""
    return x / x.max()


def count_correct(pred, target):
    """Count the number of correct predictions"""
    # BCEWithLogitsLoss: the model outputs a raw logit, so a negative value is
    # read as class 0 and a positive value as class 1.
    if pred.dim() == 1:
        if ((pred.item() < 0 and target.item() == 0) or
                (pred.item() > 0 and target.item() == 1)):
            return 1
        else:
            return 0
    # Fallback so the function always returns an integer
    return 0
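
# For instance, with the logit convention above:
#   count_correct(torch.tensor([0.8]), torch.tensor(1.))  -> 1
#   count_correct(torch.tensor([0.8]), torch.tensor(0.))  -> 0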


class PerfectNet(nn.Module):
    """Network describing the perfect answer to our problem
    (for reference) """

    def __init__(self, onnx):
        """
        :param onnx: transform for onnx conversion
        :type onnx: bool
        """
        super(PerfectNet, self).__init__()
        self.onnx = onnx

    def forward(self, x):
        # initial flattening
        if self.onnx:
            x = x
        else:
            x = x.view(-1, 3)
        # ISAIEH does not support reshape operator yet
        x1supdelta = F.relu(x[0] - 0.3)
        x3supbeta = F.relu(x[2] - 0.25)
        x2infalha = F.relu(0.5 - x[1])
        safemove = approx_i(x3supbeta + x2infalha)
        conjuct = approx_i(x1supdelta + safemove - 1.5)
        output = F.relu(x[0] - 0.7) + conjuct

        if self.onnx:
            return output
        else:
            return output.squeeze()
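
# Note on the constants above: the hard-coded thresholds 0.3, 0.25, 0.5 and
# 0.7 appear to correspond to DELTA1, BETA, ALPHA and DELTA2 in the default
# dataset file name handled by parse() below.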


class Net(nn.Module):
    """Net architecture, uses Glorot initialization"""

    def __init__(self, onnx):
        """
        :param onnx: transform for onnx conversion
        :type onnx: bool
        """
        super(Net, self).__init__()
        self.onnx = onnx
        self.l_1 = nn.Linear(3, 10, bias=True)
        self.l_2 = nn.Linear(10, 10, bias=True)
        self.l_3 = nn.Linear(10, 5, bias=True)
        self.l_4 = nn.Linear(5, 2, bias=True)
        self.l_5 = nn.Linear(2, 1, bias=True)
        nn.init.xavier_uniform_(self.l_1.weight, gain=1)
        nn.init.xavier_uniform_(self.l_2.weight, gain=1)
        nn.init.xavier_uniform_(self.l_3.weight, gain=1)
        nn.init.xavier_uniform_(self.l_4.weight, gain=1)
        nn.init.xavier_uniform_(self.l_5.weight, gain=1)

    def forward(self, x):
        # ISAIEH does not support reshape operator yet
        x = self.l_1(x)
        x = F.relu(x)
        x = self.l_2(x)
        x = F.relu(x)
        x = self.l_3(x)
        x = F.relu(x)
        x = self.l_4(x)
        x = F.relu(x)
        x = self.l_5(x)
        return x
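
# Shape sketch (illustrative): with onnx=False, Net(onnx=False)(torch.rand(4, 3))
# returns a tensor of shape (4, 1); the training loop squeezes this last
# dimension before computing BCEWithLogitsLoss against the (4,)-shaped targets.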


def count_parameters(model):
    """Count the number of parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
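
# Sanity check (illustrative): for the Net class above,
# count_parameters(Net(onnx=False)) == 220
# (40 + 110 + 55 + 12 + 3 weights and biases for l_1 .. l_5).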


def train(args, model, device, train_loader, test_loader,
          loss_fn, optimizer, epoch, f, writer):
    """Training function computing gradient descent for a given architecture,
    optimizer, loss and epoch.

    :param args: arguments of script
    :type args: parser.ArgumentParser()
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param train_loader: training set dataloader
    :type train_loader: torch.utils.data.DataLoader
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of net
    :type loss_fn: torch.nn.functional
    :param optimizer: optimizer to use
    :type optimizer: torch.optim
    :param epoch: current epoch number
    :type epoch: int
    :param f: open log file to write to
    :type f: file object
    :param writer: a tensorboard writer
    """
    # Compute forward and backward pass and update weights
    loss_train = 0
    correct_train = 0
    accuracy_train = 0
    loss_eval = 0
    correct_eval = 0
    accuracy_eval = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        if args.debug:
            print("Example of input: {}".format(data[0]))
            print("Size of input: {}".format(data.size()))
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Renormalization for BCE Loss
        # t = abs(output - output.min())
        # output = t/t.max()
        # Renormalization for BCEWithLogitsLoss
        if args.debug:
            print("Example of output: {}".format(output))
            print("Corresponding target: {}".format(target))
        # normalize_tensor(output)
        if args.debug:
            print("Size of output: {}".format(output.size()))
            print("Size of target: {}".format(target.size()))
            print("Example of output: {}".format(output))
            print("Corresponding target: {}".format(target))
            print('output type:')
            print(output.dtype)
            print('target type:')
            print(target.dtype)
        loss = loss_fn(output.squeeze(), target)
        loss_train += loss.item()
        loss.backward()
        optimizer.step()
        for i in range(data.size()[0]):
            pred = output[i]
            objective = target[i]
            correct_train += count_correct(pred, objective)
        if args.verbose:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    accuracy_train = correct_train / len(train_loader.dataset)
    loss_train_norm = loss_train / len(train_loader.dataset)
    writer.add_scalar('loss/train', loss_train / (batch_idx + 1), epoch)
    writer.add_scalar('accuracy/train', accuracy_train, epoch)
    # Compute test error to compare with train error when logging
    if epoch % args.log_interval == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)
                output = model(data)
                if args.debug:
                    print("Output size of eval: {}".format(output.size()))
                    print("Target size of eval: {}".format(target.size()))
                # Renormalization for BCE Loss
                # t = abs(output - output.min())
                # output = t/t.max()
                # Renormalization for BCEWithLogitsLoss
                # normalize_tensor(output)
                loss = loss_fn(output.squeeze(), target)
                loss_eval += loss.item()
                for i in range(data.size()[0]):
                    pred = output[i]
                    objective = target[i]
                    correct_eval += count_correct(pred, objective)
                    if args.debug:
                        print('Pred: {}'.format(pred))
                        print('Objective: {}'.format(objective))
        accuracy_eval = correct_eval / len(test_loader.dataset)
        loss_eval_norm = loss_eval / len(test_loader.dataset)

        str2wr = 'EVAL EPOCH: {epoch};{l_train};{l_eval};{acc};\n'.format(
            epoch=epoch, l_train=loss_train / (batch_idx + 1),
            l_eval=loss_eval / (batch_idx + 1), acc=accuracy_eval)
        f.write(str2wr)
        if args.verbose:
            print(str2wr)
        writer.add_scalar('loss/eval', loss_eval / (batch_idx + 1), epoch)
        writer.add_scalar('accuracy/eval', accuracy_eval, epoch)


def test(args, model, device, test_loader, loss_fn, epoch, f):
    """Testing function computing test performance for a given architecture,
    optimizer, loss and epoch.

    :param args: arguments of script
    :type args: parser.ArgumentParser()
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of net
    :type loss_fn: torch.nn.functional
    :param epoch: number of epochs (used for logging)
    :type epoch: int
    :param f: open log file to write to
    :type f: file object
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            batch_loss = 0
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Renormalization for BCE Loss
            # t = abs(output - output.min())
            # output = t/t.max()
            # Renormalization for BCEWithLogitsLoss
            # normalize_tensor(output)
            batch_loss = loss_fn(output.squeeze(), target)
            test_loss += batch_loss
            for i in range(data.size()[0]):
                sample = data[i]
                pred = output[i]
                objective = target[i]
                correct += count_correct(pred, objective)
                f.write(str(sample) + ';' + str(pred.item()) +
                        ';' + str(objective.item()) + '\n')
            if batch_idx % args.log_interval == 0:
                f.write('Test Epoch: {} '.format(epoch) +
                        '[{}/'.format(batch_idx*len(data)) +
                        '{}'.format(len(test_loader.dataset)) +
                        '({:.0f}%)]'.format(100.*batch_idx/len(test_loader)) +
                        '\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
                if args.verbose:
                    print('Test Epoch: {}'.format(epoch) +
                          '[{}/'.format(batch_idx*len(data)) +
                          '{}'.format(len(test_loader.dataset)) +
                          '({:.0f}%)]'.format(100.*batch_idx/len(test_loader)) +
                          '\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
    average_loss = test_loss/(batch_idx+1)
    f.write('#########\nTest results\n')
    f.write('Total test loss: {:.4f}\n'.format(test_loss.item()))
    f.write('Average loss: {:.4f}, '.format(average_loss.item()) +
            'Accuracy: {}/{} ({:.4f}%)\n'.format(
                correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))
    print('#########\nTest results\n')
    print('Total test loss: {:.4f}\n'.format(test_loss.item()))
    print('Average loss: {:.4f}, '.format(average_loss.item()) +
          'Accuracy: {}/{} ({:.4f}%)\n'.format(
              correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))
    f.close()


def save_onnx(epoch, model, timestmp, device):
    """Save the provided model on the standard .ONNX binary
    output format after a single inference pass. Log filepath is in the
    arguments of script.

    :param epoch: number of epochs (used in the output file name)
    :type epoch: int
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param timestmp: timestamp string used in the output file name
    :type timestmp: str
    :param device: device to work on
    :type device: torch.device
    """
    dummy_input = 42 * torch.ones(1, 1, 3, device=device)
    input_names = ["actual_input"]
    output_names = ["actual_output"]
    file_name = timestmp + '_save_epochs_' + str(epoch) + '.onnx'
    if not os.path.exists(os.path.join('checkpoints', 'ONNX')):
        os.makedirs(os.path.join('checkpoints', 'ONNX'))
    savedir = os.path.join('checkpoints', 'ONNX', file_name)
    torch.onnx.export(model, dummy_input, savedir, verbose=True,
                      input_names=input_names, output_names=output_names)
    print('Trained algorithm parameters saved at {}'.format(savedir))
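
# One way to sanity-check the exported file, assuming the `onnx` package is
# installed (illustrative only):
#   import onnx
#   onnx_model = onnx.load(savedir)
#   onnx.checker.check_model(onnx_model)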


def parse():
    """A function to parse various arguments from the command line.
    Returns the relevant arguments.
    """
    parser = argparse.ArgumentParser(description="""PyTorch Train script.\n
    required parameters: dataset and architecture.""")
    parser.add_argument('--batch-size', type=int, default=100, metavar='BS',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='BST',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=50, metavar='E',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--starting-epoch', type=int, default=0, metavar='SE',
                        help='starting epoch if warm-start')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd', type=float, default=0.0005, metavar='WD',
                        help='weight decay (default: 0.0005)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='MU',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=1, metavar='LI',
                        help="""how many batches to wait
                        before logging training status""")
    parser.add_argument('--save-interval', type=int, default=5, metavar='SI',
                        help="""how many epochs to wait
                        before saving weights""")
    parser.add_argument('--fpath', type=str,
                        default='dataset/data_ALPHA=0.5_BETA_=0.25_DELTA1=0.3_DELTA2=0.7_EPSILON=0.09.npy',
                        metavar='DP',
                        help='path of the input data file')
    parser.add_argument('--no-train', dest='train', action='store_false',
                        help='do not train net')
    parser.add_argument('--no-test', dest='test', action='store_false',
                        help='do not test net')
    parser.add_argument('--pretrained', type=str, default='', metavar='PRTND',
                        help='filepath of a pre-trained model')
    parser.add_argument('--onnx', dest='onnx', action='store_true',
                        help='save parameters in ONNX format')
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help='add additional verbosity for debug')
    parser.add_argument('--verbose', dest='verbose', action='store_true',
                        help='make training and testing more verbose')
    parser.set_defaults(train=True, test=True, onnx=True,
                        debug=False, verbose=False, scheduler=False)
    args = parser.parse_args()
    return args
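
# Typical invocations (illustrative only; adapt the script name and paths):
#   python train.py --epochs 50 --batch-size 100 --verbose
#   python train.py --no-train --pretrained checkpoints/<saved_model>.pth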


def instanciate_dataset(fpath):
    """Instantiate the train and test datasets from the data file at fpath."""
    train_dataset = pytorch_dataset.CollisionDataset(
        fpath, train=True, transform=transforms.Compose([transforms.ToTensor()]))
    test_dataset = pytorch_dataset.CollisionDataset(
        fpath, train=False,
        transform=transforms.Compose([transforms.ToTensor()]))
    return train_dataset, test_dataset


def main():
    args = parse()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # Sets random seed generator
    torch.manual_seed(args.seed)
    # Wrapper around the actual device used
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Using device {}'.format(device))
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_dataset, test_dataset = instanciate_dataset(args.fpath)
    model = Net(onnx=args.onnx).to(device)
    # Train/test loaders
    # Format train and test sets to feed the network
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.test_batch_size, shuffle=False, **kwargs)
    # Optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr, momentum=args.momentum,
                          weight_decay=args.wd)
    # Loss function
    loss_fn = nn.BCEWithLogitsLoss()
    # Pre-trained model loading
    if args.pretrained != '':
        print("Loading pre-trained model from " + args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location=device))
        start_epoch = args.starting_epoch
        stop_epoch = start_epoch + args.epochs
    else:
        start_epoch = 0
        stop_epoch = args.epochs
    timestmp = datetime.now().strftime("%d-%m-%Y:%H:%M:%S")
    logs_name = 'max_epoch_' + str(stop_epoch) + \
        'batchsize_' + str(args.batch_size) + \
        '_' + timestmp
    # Make sure the output directories exist before writing logs/checkpoints
    os.makedirs('logs', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)
    # Main loop
    if args.train:
        print('Begin training!\n')
        print('Model: \n {}'.format(model))
        print('Number of parameters: {}'.format(count_parameters(model)))
        logs_fname = logs_name + '.train.log'
        f = open('logs/' + logs_fname, 'w')
        f.write('EPOCH;TRAIN LOSS;EVAL LOSS;EVAL ACCURACY\n')
        writer = SummaryWriter('logs/tensorboard/' +
                               timestmp + logs_name + '.tensorboard')
        for epoch in range(start_epoch, stop_epoch+1):
            train(args, model, device, train_loader, test_loader,
                  loss_fn, optimizer, epoch, f, writer)
            if epoch % args.save_interval == 0:
                savepath = ('checkpoints/epoch_' + str(epoch) +
                            '_' + logs_name + '.pth')
                torch.save(model.state_dict(), savepath)
                print('Model saved in ' + savepath)
        print("Logs in " + str(f.name))
        f.close()
    if args.test:
        print('Begin testing!\n')
        logs_fname = logs_name + '.test.log'
        f = open('logs/' + logs_fname, 'w')
        test(args, model, device, test_loader, loss_fn, args.epochs, f)
        print("Logs in " + str(f.name))
    if args.onnx:
        print('Saving model into ONNX format after one inference!\n')
        save_onnx(stop_epoch, model, timestmp, device)


if __name__ == '__main__':
    main()