# M2_SETI/IA/seti_master-master/code/train.py
"""Code used to train and test the neural network used in the tutorial.
Requires files architecture.py, customDatasets.py for custom architectures and datasets handling
"""
import argparse
import os
import sys
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from tensorboardX import SummaryWriter
from dataset import pytorch_dataset # local file
def approx_i(x):
"""Approximation of identity function except on a small width"""
return F.relu(1 - F.relu(1 - 1000 * x))
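# Worked values for approx_i (evaluated by hand from the formula above):
#   approx_i(-0.2)   -> 0.0   (1 - 1000*x >= 1, so the outer ReLU clips to 0)
#   approx_i(0.0005) -> 0.5   (halfway along the ramp)
#   approx_i(0.1)    -> 1.0   (1 - 1000*x <= 0, inner ReLU gives 0, outer gives 1)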
def normalize_tensor(x):
"""Normalize tensor to avoid NaN issues"""
x = x / x.max()
def count_correct(pred, target):
    """Count the number of correct predictions.

    Predictions are raw logits (the net is trained with BCEWithLogitsLoss),
    so the decision threshold is 0: a negative logit means class 0 and a
    positive logit means class 1.
    """
    if pred.dim() == 1:
        if ((pred.item() < 0 and target.item() == 0) or
                (pred.item() > 0 and target.item() == 1)):
            return 1
        return 0
    # Fall back to 0 correct predictions for unexpected shapes instead of
    # implicitly returning None.
    return 0
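# Example (a sketch): count_correct(torch.tensor([2.3]), torch.tensor(1.)) == 1,
# since a positive logit is read as class 1 and the target here is 1.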
class PerfectNet(nn.Module):
"""Network describing the perfect answer to our problem
(for reference) """
def __init__(self, onnx):
"""
:param onnx: transform for onnx conversion
:type onnx: bool
"""
super(PerfectNet, self).__init__()
self.onnx = onnx
    def forward(self, x):
        # Initial flattening, skipped when exporting to ONNX:
        # ISAIEH does not support the reshape operator yet.
        if not self.onnx:
            x = x.view(-1, 3)
        x1supdelta = F.relu(x[0] - 0.3)
        x3supbeta = F.relu(x[2] - 0.25)
        x2infalpha = F.relu(0.5 - x[1])
        safemove = approx_i(x3supbeta + x2infalpha)
        conjunct = approx_i(x1supdelta + safemove - 1.5)
        output = F.relu(x[0] - 0.7) + conjunct
if self.onnx:
return output
else:
return output.squeeze()
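# Reading of the constants above (an observation, not part of the original
# source): the thresholds 0.3, 0.25, 0.5 and 0.7 match the DELTA1, BETA, ALPHA
# and DELTA2 values in the default dataset file name used by --fpath below, so
# PerfectNet hard-wires the labelling rule instead of learning it.
# Worked example (computed by hand from the formulas above, treating
# x[0], x[1], x[2] as the three scalar input features): for x = (0.8, 0.4, 0.3),
#   x1supdelta = 0.5, x3supbeta = 0.05, x2infalpha = 0.1,
#   safemove = approx_i(0.15) = 1, conjunct = approx_i(0.0) = 0,
#   output = relu(0.8 - 0.7) + 0 = 0.1 > 0, i.e. predicted class 1.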
class Net(nn.Module):
"""Net architecture, uses Glorot initialization"""
def __init__(self, onnx):
"""
:param onnx: transform for onnx conversion
:type onnx: bool
"""
super(Net, self).__init__()
self.onnx = onnx
self.l_1 = nn.Linear(3, 10, bias=True)
self.l_2 = nn.Linear(10, 10, bias=True)
self.l_3 = nn.Linear(10, 5, bias=True)
self.l_4 = nn.Linear(5, 2, bias=True)
self.l_5 = nn.Linear(2, 1, bias=True)
nn.init.xavier_uniform_(self.l_1.weight, gain=1)
nn.init.xavier_uniform_(self.l_2.weight, gain=1)
nn.init.xavier_uniform_(self.l_3.weight, gain=1)
nn.init.xavier_uniform_(self.l_4.weight, gain=1)
nn.init.xavier_uniform_(self.l_5.weight, gain=1)
def forward(self, x):
        # No input flattening here: ISAIEH does not support the reshape operator yet.
x = self.l_1(x)
x = F.relu(x)
x = self.l_2(x)
x = F.relu(x)
x = self.l_3(x)
x = F.relu(x)
x = self.l_4(x)
x = F.relu(x)
x = self.l_5(x)
return x
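# Minimal usage sketch (an illustration, not called anywhere in this script):
# the network maps a 3-feature input to a single raw logit, to be thresholded
# at 0 as in count_correct above.
#   net = Net(onnx=False)
#   logit = net(torch.tensor([[0.8, 0.4, 0.3]]))   # shape (1, 1)
#   predicted_class = int(logit.item() > 0)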
def count_parameters(model):
"""Count the number of parameters in a model"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
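# For the Net defined above this gives 220 trainable parameters:
#   (3*10 + 10) + (10*10 + 10) + (10*5 + 5) + (5*2 + 2) + (2*1 + 1)
#   = 40 + 110 + 55 + 12 + 3 = 220   (weights + biases of each Linear layer)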
def train(args, model, device, train_loader, test_loader,
loss_fn, optimizer, epoch, f, writer):
"""Training function computing gradient descent for a given architecture,
optimizer, loss and epoch.
    :param args: command-line arguments of the script
    :type args: argparse.Namespace
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param train_loader: training set dataloader
    :type train_loader: torch.utils.data.DataLoader
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of the net
    :type loss_fn: torch.nn.functional
    :param optimizer: optimizer to use
    :type optimizer: torch.optim
    :param epoch: current epoch number
    :type epoch: int
    :param f: open log file to write to
    :type f: file object
    :param writer: a tensorboard writer
    :type writer: tensorboardX.SummaryWriter
    """
# Compute forward and backward pass and update weights
loss_train = 0
correct_train = 0
accuracy_train = 0
loss_eval = 0
correct_eval = 0
accuracy_eval = 0
for batch_idx, (data, target) in enumerate(train_loader):
model.train()
optimizer.zero_grad()
if args.debug:
print("Example of input: {}".format(data[0]))
print("Size of input: {}".format(data.size()))
data, target = data.to(device), target.to(device)
output = model(data)
        # Renormalization for BCE Loss (kept for reference, unused):
        # t = abs(output - output.min())
        # output = t/t.max()
        # Renormalization for BCEWithLogitsLoss (kept for reference, unused):
        # normalize_tensor(output)
        if args.debug:
            print("Size of output: {}".format(output.size()))
            print("Size of target: {}".format(target.size()))
            print("Example of output: {}".format(output))
            print("Corresponding target: {}".format(target))
            print("Output dtype: {}".format(output.dtype))
            print("Target dtype: {}".format(target.dtype))
loss = loss_fn(output.squeeze(), target)
loss_train += loss.item()
loss.backward()
optimizer.step()
for i in range(data.size()[0]):
pred = output[i]
objective = target[i]
correct_train += count_correct(pred, objective)
if args.verbose:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
    accuracy_train = correct_train / len(train_loader.dataset)
    loss_train_norm = loss_train / len(train_loader.dataset)
    writer.add_scalar('loss/train', loss_train / len(train_loader), epoch)
    writer.add_scalar('accuracy/train', accuracy_train, epoch)
# Compute test error to compare with train error when logging saving
if epoch % args.log_interval == 0 or epoch == 1:
model.eval()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(test_loader):
data, target = data.to(device), target.to(device)
output = model(data)
if args.debug:
print("Output size of eval: {}".format(output.size()))
print("Target size of eval: {}".format(target.size()))
                # Renormalization for BCE Loss (kept for reference, unused):
                # t = abs(output - output.min())
                # output = t/t.max()
                # Renormalization for BCEWithLogitsLoss (kept for reference, unused):
                # normalize_tensor(output)
loss = loss_fn(output.squeeze(), target)
loss_eval += loss.item()
for i in range(data.size()[0]):
pred = output[i]
objective = target[i]
correct_eval += count_correct(pred, objective)
if args.debug:
print('Pred: {}'.format(pred))
print('Objective: {}'.format(objective))
        accuracy_eval = correct_eval / len(test_loader.dataset)
        loss_eval_norm = loss_eval / len(test_loader.dataset)
        # batch_idx now indexes the evaluation loop, so normalize the train
        # loss by the number of training batches rather than batch_idx + 1.
        str2wr = 'EVAL EPOCH: {epoch};{l_train};{l_eval};{acc};\n'.format(
            epoch=epoch, l_train=loss_train / len(train_loader),
            l_eval=loss_eval / len(test_loader), acc=accuracy_eval)
        f.write(str2wr)
        if args.verbose:
            print(str2wr)
        writer.add_scalar('loss/eval', loss_eval / len(test_loader), epoch)
        writer.add_scalar('accuracy/eval', accuracy_eval, epoch)
def test(args, model, device, test_loader, loss_fn, epoch, f):
"""Testing function computing test performance for a given architecture,
optimizer, loss and epoch.
    :param args: command-line arguments of the script
    :type args: argparse.Namespace
    :param model: architecture implementing a forward operation
    :type model: class deriving from nn.Module
    :param device: device to work on
    :type device: torch.device
    :param test_loader: testing set dataloader
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: cost function of the net
    :type loss_fn: torch.nn.functional
    :param epoch: epoch number used for logging
    :type epoch: int
    :param f: open log file to write to
    :type f: file object
    """
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(test_loader):
batch_loss = 0
data, target = data.to(device), target.to(device)
output = model(data)
            # Renormalization for BCE Loss (kept for reference, unused):
            # t = abs(output - output.min())
            # output = t/t.max()
            # Renormalization for BCEWithLogitsLoss (kept for reference, unused):
            # normalize_tensor(output)
batch_loss = loss_fn(output.squeeze(), target)
test_loss += batch_loss
for i in range(data.size()[0]):
sample = data[i]
pred = output[i]
objective = target[i]
correct += count_correct(pred, objective)
f.write(str(sample) + ';' + str(pred.item()) +
';' + str(objective.item()) + '\n')
if batch_idx % args.log_interval == 0:
f.write('Test Epoch: {} '.format(epoch) +
'[{}/'.format(batch_idx*len(data)) +
'{}'.format(len(test_loader.dataset)) +
'({:.0f}%)]'.format(100.*batch_idx/len(test_loader)) +
'\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
if args.verbose:
print('Test Epoch: {}'.format(epoch) +
'[{}/'.format(batch_idx*len(data)) +
'{}'.format(len(test_loader.dataset)) +
'({:.0f}%)]'.format(100.*batch_idx/len(test_loader)) +
'\tBatch Loss: {:.6f}\n'.format(batch_loss.item()))
average_loss = test_loss/(batch_idx+1)
f.write('#########\nTest results\n')
f.write('Total test loss: {:.4f}\n'.format(test_loss.item()))
f.write('Average loss: {:.4f}, '.format(average_loss.item()) +
'Accuracy: {}/{} ({:.4f}%)\n'.format(
correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
print('#########\nTest results\n')
print('Total test loss: {:.4f}\n'.format(test_loss.item()))
print('Average loss: {:.4f}, '.format(average_loss.item()) +
'Accuracy: {}/{} ({:.4f}%)\n'.format(
correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
f.close()
def save_onnx(epoch, model, timestmp, device):
"""Save the provided model on the standard .ONNX binary
output format after a single inference pass. Log filepath is in the
arguments of script.
:param epoch: number of epoch
:param model: architecture implementing a forward operation
:type model: class deriving from nn.Module
:param device: device to work on
:type device: torch.device
"""
dummy_input = 42 * torch.ones(1, 1, 3, device=device)
input_names = ["actual_input"]
output_names = ["actual_output"]
file_name = timestmp + '_save_epochs_' + str(epoch) + '.onnx'
if not os.path.exists(os.path.join('checkpoints', 'ONNX')):
os.makedirs(os.path.join('checkpoints', 'ONNX'))
savedir = os.path.join('checkpoints', 'ONNX', file_name)
torch.onnx.export(model, dummy_input, savedir, verbose=True,
input_names=input_names, output_names=output_names)
print('Trained algorithm parameters saved at {}'.format(savedir))
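    # The exported file can then be checked independently of PyTorch, for
    # instance with onnxruntime (not a dependency of this script; sketch only):
    #   import numpy as np
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(savedir)
    #   out = sess.run(None, {"actual_input": np.full((1, 1, 3), 42.0, dtype=np.float32)})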
def parse():
"""A function to parse various arguments from the command line.
Returns the relevant arguments.
"""
parser = argparse.ArgumentParser(description="""PyTorch Train script.\n
required parameters: dataset and architecture.""")
parser.add_argument('--batch-size', type=int, default=100, metavar='BS',
help='input batch size for training (default: 100)')
parser.add_argument('--test-batch-size', type=int, default=1000,
metavar='BST',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=50, metavar='E',
help='number of epochs to train (default: 50)')
parser.add_argument('--starting-epoch', type=int, default=0, metavar='SE',
help='Starting epoch if warm-start')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--wd', type=float, default=0.0005, metavar='WD',
help='weight decay (default: 0.0005)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='MU',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=1, metavar='LI',
help="""how many batches to wait
before logging training status""")
parser.add_argument('--save-interval', type=int, default=5, metavar='SI',
help="""how many epochs to wait
before saving weights""")
parser.add_argument('--fpath', type=str,
default='dataset/data_ALPHA=0.5_BETA_=0.25_DELTA1=0.3_DELTA2=0.7_EPSILON=0.09.npy',
metavar='DP',
help='directory to fetch inputs data')
parser.add_argument('--no-train', dest='train', action='store_false',
help='do not train net')
parser.add_argument('--no-test', dest='test', action='store_false',
help='do not test net')
parser.add_argument('--pretrained', type=str, default='', metavar='PRTND',
help='filepath of a pre-trained model')
parser.add_argument('--onnx', dest='onnx', action='store_true',
help='save parameters on ONNX format')
parser.add_argument('--debug', dest='debug', action='store_true',
help='add additional verbosity for debug')
parser.add_argument('--verbose', dest='verbose', action='store_true',
help='Training and testing is more verbose')
parser.set_defaults(train=True, test=True, onnx=True,
debug=False, verbose=False, scheduler=False)
args = parser.parse_args()
return args
def instantiate_dataset(fpath):
    """Instantiate the training and testing datasets from the given data file.
    """
train_dataset = pytorch_dataset.CollisionDataset(
fpath, train=True, transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = pytorch_dataset.CollisionDataset(
fpath, train=False,
transform=transforms.Compose([transforms.ToTensor()]))
return train_dataset, test_dataset
def main():
args = parse()
use_cuda = not args.no_cuda and torch.cuda.is_available()
# Sets random seed generator
torch.manual_seed(args.seed)
# Wrapper around the actual device used
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device {}'.format(device))
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_dataset, test_dataset = instantiate_dataset(args.fpath)
model = Net(onnx=args.onnx).to(device)
# Train/test loaders
# Format train and test sets to feed the network
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.test_batch_size, shuffle=False, **kwargs)
# Optimizer
optimizer = optim.SGD(model.parameters(),
lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
# Loss function
loss_fn = nn.BCEWithLogitsLoss()
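    # BCEWithLogitsLoss applies the sigmoid internally, which is why Net.forward
    # returns a raw logit and why count_correct thresholds predictions at 0.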
# Pre-trained model loading
if args.pretrained != '':
print("Loading pre-trained model from "+args.pretrained)
model.load_state_dict(torch.load(args.pretrained, map_location=device))
start_epoch = args.starting_epoch
stop_epoch = start_epoch+args.epochs
else:
start_epoch = 0
stop_epoch = args.epochs
timestmp = datetime.now().strftime("%d-%m-%Y:%H:%M:%S")
    logs_name = 'max_epoch_' + str(stop_epoch) + \
        '_batchsize_' + str(args.batch_size) + \
        '_' + timestmp
# Main loop
if args.train:
print('Begin training!\n')
print('Model: \n {}'.format(model))
print('Number of parameters: {}'.format(count_parameters(model)))
        logs_fname = logs_name + '.train.log'
        f = open('logs/' + logs_fname, 'w')
        f.write('EPOCH;TRAIN LOSS;EVAL LOSS;EVAL ACCURACY\n')
writer = SummaryWriter('logs/tensorboard/' +
timestmp + logs_name + '.tensorboard')
for epoch in range(start_epoch, stop_epoch+1):
train(args, model, device, train_loader, test_loader,
loss_fn, optimizer, epoch, f, writer)
if epoch % args.save_interval == 0:
savepath = ('checkpoints/epoch_' + str(epoch) +
'_' + logs_name + '.pth')
torch.save(model.state_dict(), savepath)
print('Model saved in '+savepath)
print("Logs in "+str(f.name))
f.close()
if args.test:
print('Begin testing!\n')
logs_fname = logs_name + '.test.log'
f = open('logs/' + logs_fname, 'w')
test(args, model, device, test_loader, loss_fn, args.epochs, f)
print("Logs in "+str(f.name))
if args.onnx:
print('Saving model into ONNX format after one inference!\n')
save_onnx(stop_epoch, model, timestmp, device)
if __name__ == '__main__':
main()