"""Code used to train and test the neural network used in the tutorial. Requires files architecture.py, customDatasets.py for custom architectures and datasets handling """ import argparse import os import sys from datetime import datetime import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import transforms from tensorboardX import SummaryWriter from dataset import pytorch_dataset # local file def approx_i(x): """Approximation of identity function except on a small width""" return F.relu(1 - F.relu(1 - 1000 * x)) def normalize_tensor(x): """Normalize tensor to avoid NaN issues""" x = x / x.max() def count_correct(pred, target): """Count the number of correct predictions""" # BCEWithLogitsLoss if pred.dim() == 1: if ((pred.item() < 0 and target.item() == 0) or (pred.item() > 0 and target.item() == 1)): return 1 else: return 0 class PerfectNet(nn.Module): """Network describing the perfect answer to our problem (for reference) """ def __init__(self, onnx): """ :param onnx: transform for onnx conversion :type onnx: bool """ super(PerfectNet, self).__init__() self.onnx = onnx def forward(self, x): # initial flattening if self.onnx: x = x else: x = x.view(-1, 3) # ISAIEH does not support reshape operator yet x1supdelta = F.relu(x[0] - .3) x3supbeta = F.relu(x[2] - 0.25) x2infalha = F.relu(0.5-x[1]) safemove = approx_i(x3supbeta + x2infalha) conjuct = approx_i(x1supdelta + safemove - 1.5) output = F.relu(x[0] - 0.7) + conjuct if self.onnx: return output else: return output.squeeze() class Net(nn.Module): """Net architecture, uses Glorot initialization""" def __init__(self, onnx): """ :param onnx: transform for onnx conversion :type onnx: bool """ super(Net, self).__init__() self.onnx = onnx self.l_1 = nn.Linear(3, 10, bias=True) self.l_2 = nn.Linear(10, 10, bias=True) self.l_3 = nn.Linear(10, 5, bias=True) self.l_4 = nn.Linear(5, 2, bias=True) self.l_5 = nn.Linear(2, 1, bias=True) nn.init.xavier_uniform_(self.l_1.weight, gain=1) nn.init.xavier_uniform_(self.l_2.weight, gain=1) nn.init.xavier_uniform_(self.l_3.weight, gain=1) nn.init.xavier_uniform_(self.l_4.weight, gain=1) nn.init.xavier_uniform_(self.l_5.weight, gain=1) def forward(self, x): # ISAIEH does not support reshape operator yet x = self.l_1(x) x = F.relu(x) x = self.l_2(x) x = F.relu(x) x = self.l_3(x) x = F.relu(x) x = self.l_4(x) x = F.relu(x) x = self.l_5(x) return x def count_parameters(model): """Count the number of parameters in a model""" return sum(p.numel() for p in model.parameters() if p.requires_grad) def train(args, model, device, train_loader, test_loader, loss_fn, optimizer, epoch, f, writer): """Training function computing gradient descent for a given architecture, optimizer, loss and epoch. 

    :param args: arguments of the script
    :type args: argparse.Namespace
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param train_loader: training set dataloader
    :type train_loader: torch.utils.data.DataLoader
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of the net
    :type loss_fn: torch.nn module (e.g. nn.BCEWithLogitsLoss)
    :param optimizer: optimizer to use
    :type optimizer: torch.optim
    :param epoch: current epoch number
    :type epoch: int
    :param f: open log file
    :type f: file object
    :param writer: a tensorboard writer
    """
    # Compute forward and backward passes and update the weights
    loss_train = 0
    correct_train = 0
    accuracy_train = 0
    loss_eval = 0
    correct_eval = 0
    accuracy_eval = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        if args.debug:
            print("Example of input: {}".format(data[0]))
            print("Size of input: {}".format(data.size()))
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Renormalization for BCE Loss
        # t = abs(output - output.min())
        # output = t/t.max()
        # Renormalization for BCEWithLogitsLoss
        # output = normalize_tensor(output)
        if args.debug:
            print("Size of output: {}".format(output.size()))
            print("Size of target: {}".format(target.size()))
            print("Example of output: {}".format(output))
            print("Corresponding target: {}".format(target))
            print('output type: {}'.format(output.dtype))
            print('target type: {}'.format(target.dtype))
        loss = loss_fn(output.squeeze(), target)
        loss_train += loss.item()
        loss.backward()
        optimizer.step()
        for i in range(data.size()[0]):
            pred = output[i]
            objective = target[i]
            correct_train += count_correct(pred, objective)
        if args.verbose:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    # Per-epoch training metrics (batch_idx indexes the last training batch)
    num_train_batches = batch_idx + 1
    accuracy_train = correct_train / len(train_loader.dataset)
    loss_train_norm = loss_train / len(train_loader.dataset)
    writer.add_scalar('loss/train', loss_train / num_train_batches, epoch)
    writer.add_scalar('accuracy/train', accuracy_train, epoch)

    # Compute the test error to compare with the train error when logging
    if epoch % args.log_interval == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)
                output = model(data)
                if args.debug:
                    print("Output size of eval: {}".format(output.size()))
                    print("Target size of eval: {}".format(target.size()))
                # Renormalization for BCE Loss
                # t = abs(output - output.min())
                # output = t/t.max()
                # Renormalization for BCEWithLogitsLoss
                # output = normalize_tensor(output)
                loss = loss_fn(output.squeeze(), target)
                loss_eval += loss.item()
                for i in range(data.size()[0]):
                    pred = output[i]
                    objective = target[i]
                    correct_eval += count_correct(pred, objective)
                    if args.debug:
                        print('Pred: {}'.format(pred))
                        print('Objective: {}'.format(objective))
        accuracy_eval = correct_eval / len(test_loader.dataset)
        loss_eval_norm = loss_eval / len(test_loader.dataset)
        # batch_idx now indexes the last evaluation batch
        str2wr = 'EVAL EPOCH: {epoch};{l_train};{l_eval};{acc};\n'.format(
            epoch=epoch,
            l_train=loss_train / num_train_batches,
            l_eval=loss_eval / (batch_idx + 1),
            acc=accuracy_eval)
        f.write(str2wr)
        if args.verbose:
            print(str2wr)
        writer.add_scalar('loss/eval', loss_eval / (batch_idx + 1), epoch)
        writer.add_scalar('accuracy/eval', accuracy_eval, epoch)


def test(args, model, device, test_loader, loss_fn, epoch, f):
    """Testing function computing test performance for a given architecture,
    optimizer, loss and epoch.

    :param args: arguments of the script
    :type args: argparse.Namespace
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of the net
    :type loss_fn: torch.nn module (e.g. nn.BCEWithLogitsLoss)
    :param epoch: number of epochs
    :type epoch: int
    :param f: open log file
    :type f: file object
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Renormalization for BCE Loss
            # t = abs(output - output.min())
            # output = t/t.max()
            # Renormalization for BCEWithLogitsLoss
            # output = normalize_tensor(output)
            batch_loss = loss_fn(output.squeeze(), target)
            test_loss += batch_loss
            for i in range(data.size()[0]):
                sample = data[i]
                pred = output[i]
                objective = target[i]
                correct += count_correct(pred, objective)
                f.write(str(sample) + ';' + str(pred.item()) + ';'
                        + str(objective.item()) + '\n')
            if batch_idx % args.log_interval == 0:
                f.write('Test Epoch: {} '.format(epoch)
                        + '[{}/'.format(batch_idx * len(data))
                        + '{}'.format(len(test_loader.dataset))
                        + ' ({:.0f}%)]'.format(100. * batch_idx / len(test_loader))
                        + '\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
                if args.verbose:
                    print('Test Epoch: {} '.format(epoch)
                          + '[{}/'.format(batch_idx * len(data))
                          + '{}'.format(len(test_loader.dataset))
                          + ' ({:.0f}%)]'.format(100. * batch_idx / len(test_loader))
                          + '\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
    average_loss = test_loss / (batch_idx + 1)
    f.write('#########\nTest results\n')
    f.write('Total test loss: {:.4f}\n'.format(test_loss.item()))
    f.write('Average loss: {:.4f}, '.format(average_loss.item())
            + 'Accuracy: {}/{} ({:.4f}%)\n'.format(
                correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))
    print('#########\nTest results\n')
    print('Total test loss: {:.4f}\n'.format(test_loss.item()))
    print('Average loss: {:.4f}, '.format(average_loss.item())
          + 'Accuracy: {}/{} ({:.4f}%)\n'.format(
              correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))
    f.close()


def save_onnx(epoch, model, timestmp, device):
    """Save the provided model in the standard .ONNX binary format after a
    single inference pass. The log filepath is in the arguments of the script.

    :param epoch: epoch number used in the output file name
    :type epoch: int
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param timestmp: timestamp used in the output file name
    :type timestmp: string
    :param device: device to work on
    :type device: torch.device
    """
    dummy_input = 42 * torch.ones(1, 1, 3, device=device)
    input_names = ["actual_input"]
    output_names = ["actual_output"]
    file_name = timestmp + '_save_epochs_' + str(epoch) + '.onnx'
    if not os.path.exists(os.path.join('checkpoints', 'ONNX')):
        os.makedirs(os.path.join('checkpoints', 'ONNX'))
    savedir = os.path.join('checkpoints', 'ONNX', file_name)
    torch.onnx.export(model, dummy_input, savedir, verbose=True,
                      input_names=input_names, output_names=output_names)
    print('Trained algorithm parameters saved at {}'.format(savedir))


def parse():
    """Parse the various arguments from the command line and return them."""
    parser = argparse.ArgumentParser(description="""PyTorch Train script.\n
    required parameters: dataset and architecture.""")
    parser.add_argument('--batch-size', type=int, default=100, metavar='BS',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='BST',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=50, metavar='E',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--starting-epoch', type=int, default=0, metavar='SE',
                        help='starting epoch if warm-start')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd', type=float, default=0.0005, metavar='WD',
                        help='weight decay (default: 0.0005)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='MU',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=1, metavar='LI',
                        help='how many batches to wait before logging '
                             'training status')
    parser.add_argument('--save-interval', type=int, default=5, metavar='SI',
                        help='how many epochs to wait before saving weights')
    parser.add_argument('--fpath', type=str,
                        default='dataset/data_ALPHA=0.5_BETA_=0.25_DELTA1=0.3_DELTA2=0.7_EPSILON=0.09.npy',
                        metavar='DP',
                        help='path of the input data file')
    parser.add_argument('--no-train', dest='train', action='store_false',
                        help='do not train the net')
    parser.add_argument('--no-test', dest='test', action='store_false',
                        help='do not test the net')
    parser.add_argument('--pretrained', type=str, default='', metavar='PRTND',
                        help='filepath of a pre-trained model')
    parser.add_argument('--onnx', dest='onnx', action='store_true',
                        help='save parameters in ONNX format')
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help='add additional verbosity for debug')
    parser.add_argument('--verbose', dest='verbose', action='store_true',
                        help='make training and testing more verbose')
    parser.set_defaults(train=True, test=True, onnx=True, debug=False,
                        verbose=False, scheduler=False)
    args = parser.parse_args()
    return args


def instantiate_dataset(fpath):
    """Perform all the boring stuff with dataset initialization."""
    train_dataset = pytorch_dataset.CollisionDataset(
        fpath, train=True,
        transform=transforms.Compose([transforms.ToTensor()]))
    test_dataset = pytorch_dataset.CollisionDataset(
        fpath, train=False,
        transform=transforms.Compose([transforms.ToTensor()]))
    return train_dataset, test_dataset


def main():
    args = parse()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # Set the random seed
    torch.manual_seed(args.seed)
    # Wrapper around the actual device used
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Using device {}'.format(device))
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_dataset, test_dataset = instantiate_dataset(args.fpath)
    model = Net(onnx=args.onnx).to(device)
    # Train/test loaders: format the train and test sets to feed the network
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False,
        **kwargs)
    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.wd)
    # Loss function
    loss_fn = nn.BCEWithLogitsLoss()
    # Pre-trained model loading
    if args.pretrained != '':
        print("Loading pre-trained model from " + args.pretrained)
        model.load_state_dict(torch.load(args.pretrained,
                                         map_location=device))
        start_epoch = args.starting_epoch
        stop_epoch = start_epoch + args.epochs
    else:
        start_epoch = 0
        stop_epoch = args.epochs
    timestmp = datetime.now().strftime("%d-%m-%Y:%H:%M:%S")
    logs_name = ('max_epoch_' + str(stop_epoch)
                 + '_batchsize_' + str(args.batch_size)
                 + '_' + timestmp)
    # Main loop
    if args.train:
        print('Begin training!\n')
        print('Model: \n {}'.format(model))
        print('Number of parameters: {}'.format(count_parameters(model)))
        logs_fname = logs_name + '.train.log'
        f = open('logs/' + logs_fname, 'w')
        f.write('EPOCH;TRAIN LOSS;EVAL LOSS;EVAL ACCURACY\n')
        writer = SummaryWriter('logs/tensorboard/' + timestmp + logs_name
                               + '.tensorboard')
        for epoch in range(start_epoch, stop_epoch + 1):
            train(args, model, device, train_loader, test_loader, loss_fn,
                  optimizer, epoch, f, writer)
            if epoch % args.save_interval == 0:
                savepath = ('checkpoints/epoch_' + str(epoch) + '_'
                            + logs_name + '.pth')
                torch.save(model.state_dict(), savepath)
                print('Model saved in ' + savepath)
        print("Logs in " + str(f.name))
        f.close()
    if args.test:
        print('Begin testing!\n')
        logs_fname = logs_name + '.test.log'
        f = open('logs/' + logs_fname, 'w')
        test(args, model, device, test_loader, loss_fn, args.epochs, f)
        print("Logs in " + str(f.name))
    if args.onnx:
        print('Saving model into ONNX format after one inference!\n')
        save_onnx(stop_epoch, model, timestmp, device)


if __name__ == '__main__':
    main()
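
# Example invocation (a sketch: the script file name "train.py" and the
# existence of the "logs/", "checkpoints/" and "dataset/" directories are
# assumptions, not guaranteed by this file; create them before running):
#
#   python train.py --epochs 50 --batch-size 100 --lr 0.01 --verbose \
#       --fpath dataset/data_ALPHA=0.5_BETA_=0.25_DELTA1=0.3_DELTA2=0.7_EPSILON=0.09.npy
#
# Training logs and the tensorboard run are written under logs/, weight
# checkpoints under checkpoints/, and the ONNX export under checkpoints/ONNX/.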