"""A set of functions to visualise the output space of a neural network."""

import random
import sys

import torch
import matplotlib.pyplot as plt
from onnx2pytorch import convert


def _colour_of(y):
    """Return 'red' when the model output *y* is positive, else 'blue'.

    *y* must have an unambiguous truth value under ``y > 0`` (a Python
    number or a one-element tensor).
    """
    return 'red' if y > 0 else 'blue'


def sample2d(model, n_sample=1000, frozen_dim=0, frozen_val=0.5):
    """Sample the model on a 2-D slice of its 3-D input space.

    Two coordinates (``x1`` and ``x2``) are drawn with ``random.random()``,
    i.e. uniformly from [0, 1); the remaining coordinate is pinned to
    *frozen_val*.  The input vector handed to the model is laid out as:

    * ``frozen_dim == 0``: ``[frozen_val, x1, x2]``
    * ``frozen_dim == 1``: ``[x2, frozen_val, x1]``
    * ``frozen_dim == 2``: ``[x1, x2, frozen_val]``

    Args:
        model: callable taking a 1-D tensor of three values.  Its output is
            compared with 0, so it is assumed to be a scalar or one-element
            tensor (TODO confirm against the network actually loaded).
        n_sample: number of points to draw.
        frozen_dim: index (0, 1 or 2) of the input coordinate held constant.
        frozen_val: value assigned to the frozen coordinate.

    Returns:
        A tuple ``(dim_1, dim_2, colours)`` of equal-length lists: the ``x1``
        values, the ``x2`` values, and ``'red'`` (output > 0) or ``'blue'``
        (otherwise) for each sample.

    Raises:
        ValueError: if *frozen_dim* is not 0, 1 or 2.
    """
    # Validate once, up front, so a bad value is reported even when
    # n_sample == 0 instead of silently returning empty lists.
    if frozen_dim not in (0, 1, 2):
        raise ValueError(
            "Error, frozen_dim must be 0, 1 or 2, got {}".format(frozen_dim))

    dim_1 = []
    dim_2 = []
    colours = []
    # Only the sign of the output is used; no gradients are needed.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            if frozen_dim == 0:
                x = torch.tensor([frozen_val, x1, x2])
            elif frozen_dim == 1:
                x = torch.tensor([x2, frozen_val, x1])
            else:
                x = torch.tensor([x1, x2, frozen_val])
            dim_1.append(x1)
            dim_2.append(x2)
            colours.append(_colour_of(model(x)))
    return dim_1, dim_2, colours


def sample3d(model, n_sample=1000):
    """Sample the model on random points of its full 3-D input space.

    Each coordinate is drawn with ``random.random()``, i.e. uniformly from
    [0, 1), and the model is called on ``[x1, x2, x3]``.

    Args:
        model: callable taking a 1-D tensor of three values.  Its output is
            compared with 0, so it is assumed to be a scalar or one-element
            tensor (TODO confirm against the network actually loaded).
        n_sample: number of points to draw.

    Returns:
        A tuple ``(dim_1, dim_2, dim_3, colours)`` of equal-length lists: the
        three coordinates of every sample and ``'red'`` (output > 0) or
        ``'blue'`` (otherwise) for each sample.
    """
    dim_1 = []
    dim_2 = []
    dim_3 = []
    colours = []
    # Only the sign of the output is used; no gradients are needed.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            x3 = random.random()
            x = torch.tensor([x1, x2, x3])
            dim_1.append(x1)
            dim_2.append(x2)
            dim_3.append(x3)
            colours.append(_colour_of(model(x)))
    return dim_1, dim_2, dim_3, colours


def plot2d(x1s, x2s, colours):
    """Show a 2-D scatter plot of the points ``(x1s[j], x2s[j])``.

    Args:
        x1s: x coordinates.
        x2s: y coordinates.
        colours: sequence of matplotlib colour specs, one per point.

    The coordinate iterables are paired with ``zip``, so the shorter one
    determines how many points are drawn.  Blocks in ``plt.show()``.
    """
    points = list(zip(x1s, x2s))
    if points:
        xs, ys = zip(*points)
        # One vectorised call instead of one artist per point.
        plt.scatter(xs, ys, color=list(colours[:len(points)]))
    # plt.title("ALPHA=0.5, BETA=0.25, DELTA1=0.3, DELTA2=0.7, EPSILON=0.09")
    plt.show()


def plot3d(x1s, x2s, x3s, colours):
    """Show a 3-D scatter plot of the points ``(x1s[j], x2s[j], x3s[j])``.

    The axes are labelled 'distance', 'speed' and 'angle'.

    Args:
        x1s: x coordinates.
        x2s: y coordinates.
        x3s: z coordinates.
        colours: sequence of matplotlib colour specs, one per point.

    The coordinate iterables are paired with ``zip``, so the shortest one
    determines how many points are drawn.  Blocks in ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    points = list(zip(x1s, x2s, x3s))
    if points:
        xs, ys, zs = zip(*points)
        # One vectorised call instead of one artist per point.
        # depthshade=False is intended to keep every marker fully opaque,
        # as it was when each point was drawn as its own collection.
        ax.scatter(xs, ys, zs, color=list(colours[:len(points)]),
                   depthshade=False)
    ax.set_xlabel('distance')
    ax.set_ylabel('speed')
    ax.set_zlabel('angle')
    plt.show()


def main():
    """Load ``network.onnx`` and show the 2-D slice plot, then the 3-D plot.

    Command-line arguments: ``n_sample frozen_dim frozen_val``.
    """
    if len(sys.argv) < 4:
        sys.exit("usage: {} n_sample frozen_dim frozen_val".format(
            sys.argv[0]))
    n_sample = int(sys.argv[1])
    frozen_dim = int(sys.argv[2])
    frozen_val = float(sys.argv[3])

    # NOTE(review): confirm that `convert` as imported from onnx2pytorch is
    # callable with a file path; left exactly as in the original.
    model = convert("network.onnx")

    dim_1, dim_2, colours = sample2d(model, n_sample, frozen_dim, frozen_val)
    plot2d(dim_1, dim_2, colours)

    dim_1, dim_2, dim_3, colours = sample3d(model, n_sample)
    plot3d(dim_1, dim_2, dim_3, colours)


if __name__ == '__main__':
    main()