82 lines
2.3 KiB
Python
82 lines
2.3 KiB
Python
|
"""A set of functions to visualise the output space of a neural
|
||
|
network.
|
||
|
"""
|
||
|
import random
|
||
|
import sys
|
||
|
|
||
|
import torch
|
||
|
import matplotlib.pyplot as plt
|
||
|
from onnx2pytorch import convert
|
||
|
|
||
|
|
||
|
def sample2d(model, n_sample=1000, frozen_dim=0, frozen_val=0.5):
    """Sample *model* on a 2-D slice of the input cube and colour by output sign.

    Two input coordinates are drawn with ``random.random()`` (uniform in
    [0, 1)); the third, at index ``frozen_dim``, is held at ``frozen_val``.
    The free coordinates follow the frozen one in cyclic order, so the
    returned ``dim_1`` / ``dim_2`` are these input indices:

    * frozen_dim=0 -> (1, 2)
    * frozen_dim=1 -> (2, 0)
    * frozen_dim=2 -> (0, 1)

    Args:
        model: Callable applied to a 1-D ``torch.Tensor`` of 3 floats. Its
            output is compared with 0, so it presumably is a single-element
            tensor (confirm against the network being visualised).
        n_sample: Number of random points to evaluate.
        frozen_dim: Index (0, 1 or 2) of the input held constant.
        frozen_val: Value used for the frozen input.

    Returns:
        ``(dim_1, dim_2, colours)``: the two free coordinates of each sample,
        and ``'red'`` where the model output is > 0, ``'blue'`` otherwise.

    Raises:
        ValueError: If ``frozen_dim`` is not 0, 1 or 2.
    """
    # Validate before sampling so a bad index is reported even if n_sample == 0.
    if frozen_dim not in (0, 1, 2):
        raise ValueError(
            "Error, frozen_dim must be 0, 1 or 2, got {}".format(frozen_dim))

    dim_1 = []
    dim_2 = []
    colours = []
    # Inference only: without no_grad every forward pass builds an autograd
    # graph, which stays alive as long as the output tensor is referenced.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            x3 = frozen_val
            if frozen_dim == 0:
                x = torch.tensor([x3, x1, x2])
            elif frozen_dim == 1:
                x = torch.tensor([x2, x3, x1])
            else:
                x = torch.tensor([x1, x2, x3])
            y = model(x)
            dim_1.append(x1)
            dim_2.append(x2)
            colours.append('red' if y > 0 else 'blue')
    return dim_1, dim_2, colours
|
||
|
|
||
|
|
||
|
def sample3d(model, n_sample=1000):
    """Sample *model* over the 3-D input cube and colour by output sign.

    Each of the three inputs is drawn with ``random.random()`` (uniform in
    [0, 1)) and passed to *model* as a 1-D tensor ``[x1, x2, x3]``.

    Args:
        model: Callable applied to a 1-D ``torch.Tensor`` of 3 floats. Its
            output is compared with 0, so it presumably is a single-element
            tensor (confirm against the network being visualised).
        n_sample: Number of random points to evaluate.

    Returns:
        ``(dim_1, dim_2, dim_3, colours)``: the three coordinates of each
        sample, and ``'red'`` where the model output is > 0, ``'blue'``
        otherwise.
    """
    dim_1 = []
    dim_2 = []
    dim_3 = []
    colours = []
    # Inference only: no_grad avoids building an autograd graph per sample.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            x3 = random.random()
            y = model(torch.tensor([x1, x2, x3]))
            dim_1.append(x1)
            dim_2.append(x2)
            dim_3.append(x3)
            colours.append('red' if y > 0 else 'blue')
    return dim_1, dim_2, dim_3, colours
|
||
|
|
||
|
|
||
|
def plot2d(x1s, x2s, colours):
    """Scatter-plot 2-D samples with per-point colours and show the figure.

    Args:
        x1s: Horizontal coordinates.
        x2s: Vertical coordinates, one per entry of ``x1s``.
        colours: One matplotlib colour per point, e.g. the ``colours`` list
            returned by ``sample2d``.
    """
    # A single scatter call builds one collection for all points instead of
    # one artist per point, which is far cheaper to create and render.
    plt.scatter(list(x1s), list(x2s), c=list(colours))
    plt.show()
|
||
|
|
||
|
def plot3d(x1s, x2s, x3s, colours):
    """Scatter-plot 3-D samples with per-point colours and show the figure.

    The axes are labelled 'distance', 'speed' and 'angle' (x, y, z).

    Args:
        x1s: x coordinates.
        x2s: y coordinates.
        x3s: z coordinates.
        colours: One matplotlib colour per point, e.g. the ``colours`` list
            returned by ``sample3d``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # One collection for all points instead of one artist per point.
    # depthshade=False keeps every point fully opaque regardless of depth.
    ax.scatter(list(x1s), list(x2s), list(x3s), c=list(colours),
               depthshade=False)
    ax.set_xlabel('distance')
    ax.set_ylabel('speed')
    ax.set_zlabel('angle')
    plt.show()
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Usage: <script> N_SAMPLE FROZEN_DIM FROZEN_VAL [--model PATH]
    # argparse replaces raw sys.argv indexing so missing or invalid arguments
    # produce a usage message instead of an IndexError/ValueError traceback.
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot the sign of a network's output over random inputs.")
    parser.add_argument('n_sample', type=int,
                        help='number of random samples for each plot')
    parser.add_argument('frozen_dim', type=int, choices=(0, 1, 2),
                        help='input index held constant in the 2-D plot')
    parser.add_argument('frozen_val', type=float,
                        help='value of the frozen input in the 2-D plot')
    parser.add_argument('--model', default='network.onnx',
                        help='file passed to onnx2pytorch.convert '
                             '(default: network.onnx)')
    args = parser.parse_args()

    model = convert(args.model)
    dim_1, dim_2, colours = sample2d(
        model, args.n_sample, args.frozen_dim, args.frozen_val)
    plot2d(dim_1, dim_2, colours)
    dim_1, dim_2, dim_3, colours = sample3d(model, args.n_sample)
    plot3d(dim_1, dim_2, dim_3, colours)
|